Methods for Data Science: Course Work 2¶

In [1]:
import numpy as np
import random
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.colors import ListedColormap
import matplotlib.gridspec as gridspec
import scipy
from scipy.stats import multivariate_normal 
from scipy.sparse import linalg
from collections import defaultdict
from tqdm import tqdm
import networkx as nx

Before Reading:¶

$\cdot$ specification about chunk order:

The order of the chunk might not be continuent from the start to the end. However, the order within each section is continuent and interrupting chunk number are added print-puts based on running chunks.

$\cdot$ About the use of sklearn in 2.1:

There is an import of sklearn, which has been used in 2.1 k-mean clustering. This is only for purpose of result checking. All the code in this notebook for assessment is written by elementary python functions or allowed packages.

Enjoy!

Task 1: Neural Networks, Dimensionality Reduction and Mixture Models (65 marks)¶

1.1 Multi-Layer Perceptron (MLP) (25 marks)¶

In this section, an MLP is built from scratch. For given activation functions, SGD is used as the optimization function and KL divergence as the loss function. Then the learning rate is adjusted to find the optimal one, and then the layer width is adjusted to find its effect on the model performance. As MLP wraps thousands of millions of parameters, dropout is used to regularize the network and to contruct a 'sub-network'. And the effect of such regularization on model performance isdiscussed through training nad test losses and accuracies.

A probablistic substitute for MLP is deep Gaussian process. In the last bit of the seesion, the histogram of the outcome of the first layer are plotted for both drop-out case and non-dropout case. And the effect of dropout is discussed from this perspective.

In [2]:
# load data and do some data processing
MNIST_train = pd.read_csv("MNIST_train.csv")
MNIST_test = pd.read_csv("MNIST_test.csv")
display(MNIST_train.head())
print("Shape of MNIST_train: ", MNIST_train.shape)
print("Shape of MNIST_test: ", MNIST_test.shape)

# convert to numpy array
MNIST_train = MNIST_train.to_numpy()
MNIST_test = MNIST_test.to_numpy()

# target-predictor split
x_train, y_train = MNIST_train[:,1:]/255, MNIST_train[:,0]
x_test, y_test = MNIST_test[:,1:]/255, MNIST_test[:,0]
print("Shape of x_train: ", x_train.shape)
print("Shape of y_train: ", y_train.shape)
label 1x1 1x2 1x3 1x4 1x5 1x6 1x7 1x8 1x9 ... 28x19 28x20 28x21 28x22 28x23 28x24 28x25 28x26 28x27 28x28
0 4 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 9 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
2 7 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
3 8 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
4 2 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

5 rows × 785 columns

Shape of MNIST_train:  (6000, 785)
Shape of MNIST_test:  (1000, 785)
Shape of x_train:  (6000, 784)
Shape of y_train:  (6000,)
In [3]:
MNIST_test[:, 1:][0]
Out[3]:
array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 242, 205,
        19,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  57, 242,
       253, 253, 166,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  13,
       241, 254, 253, 253, 173,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  56,
        95, 216, 253, 254, 253, 253, 241,  87,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
        15, 195, 253, 253, 253, 254, 253, 253, 253, 239,  88,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  83, 253, 253, 253, 253, 254, 253, 253, 253, 253, 239,
        88,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,  68, 233, 253, 253, 253, 232, 214, 213, 213, 226,
       253, 253, 241,  86,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,  23, 206, 253, 253, 250, 148,  44,   0,   0,
         0,  30, 128, 246, 253, 243,  68,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0, 214, 253, 253, 253, 219,   0,   0,
         0,   0,   0,   0,   0, 110, 253, 253, 184,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,  50, 235, 253, 253, 232,  44,
         0,   0,   0,   0,   0,   0,   0,  26, 211, 253, 240,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0, 158, 254, 254, 254,
        80,   0,   0,   0,   0,   0,   0,   0,   0,  26, 213, 255, 241,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 103, 252, 253,
       225,  47,   4,   0,   0,   0,   0,   0,   0,   0,   0,  54, 253,
       253, 240,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 228,
       253, 253,  66,   0,   0,   0,   0,   0,   0,   0,   0,   0, 181,
       233, 253, 253, 120,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0, 241, 253, 228,  36,   0,   0,   0,   0,   0,   0,   0,   0,
        69, 245, 253, 253, 188,  17,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  96, 251, 253,  67,   0,   0,   0,   0,   0,   0,   0,
         0, 111, 217, 253, 253, 243,  69,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0, 121, 253, 253,  68,   0,   0,   0,   0,   0,
         0,  54, 181, 247, 253, 253, 242,  98,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0, 102, 251, 253, 228,  58,   9,  41,
        55,  97, 112, 255, 253, 253, 253, 243,  98,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0, 227, 253, 253, 229,
       180, 253, 253, 253, 253, 255, 253, 243, 184,  69,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  45, 155,
       251, 253, 253, 253, 253, 253, 253, 241, 240,  81,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0, 102, 225, 253, 253, 253, 246, 120,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0], dtype=int64)
In [4]:
# input layer
def flatten(x):
    return np.array([i.flatten() for i in x])
In [5]:
# softplus function
def softplus(x, beta=1):
    sp = np.log(1 + np.exp(beta*x)) / beta
    return sp

# derivative of softplus
def softplus_deriv(x, beta=1):
    return np.exp(beta*x) / (1 + np.exp(beta*x))
In [6]:
# softmax function
def softmax(x):
    y = np.exp(x) / np.sum(np.exp(x), axis=1)[..., np.newaxis]
    return y
In [7]:
# KL divergence loss function
def kl_loss(p, q):
    # Replace all zeros with a very small float (considering underflow issues)
    p[p == 0] = 1e-10
    q[q == 0] = 1e-10
    # return np.sum(p * np.log(p / q))/len(q)
    return np.sum(p * (np.log(p + 1e-8) - np.log(q + 1e-8)), axis=1)
In [8]:
# accuracy
def mlp_accuracy(y_pred, y_test):
    return np.mean(y_pred == y_test)
In [9]:
# prob vector to one single prediction
def mlp_prediction(y_pred):
    return np.argmax(y_pred, axis=1)
In [10]:
# compute the output error
def output_error(y_batch, a):
    y_pred = softmax(a)
    return y_pred - y_batch
In [11]:
# used to transform labels to probability distribution
def one_hot(Y):
    return np.eye(10)[Y]
In [12]:
# considering overfitting problems, use regularization technique
def dropout_mask(n, dropout_prob):
    return np.random.binomial(1, (1 - dropout_prob), size=n) / (1 - dropout_prob)

initialize parameters

In [13]:
# initialize the parameters: 3 hidden layers, each with 200 neurons; output layer with 10 neurons, one for each class
# use Glorot initialisation to initialize weights and bias 
def init_params(width=200):
    var0 = 2. / (784 + width)
    W1 = np.random.randn(784, width) * np.sqrt(var0)
    b1 = np.zeros(width)

    var1 = 2. / (width + width)
    W2 = np.random.randn(width, width) * np.sqrt(var1)
    b2 = np.zeros(width)
    var2 = 2. / (width + width)
    W3 = np.random.randn(width, width) * np.sqrt(var2)
    b3 = np.zeros(width)

    var3 = 2. / (10 + width)
    W4 = np.random.randn(width, 10) * np.sqrt(var3)
    b4 = np.zeros(10)
    
    return {'W1':W1,'W2':W2,'W3':W3,'W4':W4}, {'b1':b1,'b2':b2,'b3':b3,'b4':b4}

forward propogation

In [14]:
def forward_prop(x, weights, bias, dropout_prob=0):
    W1, W2, W3, W4 = weights.values()
    b1,b2,b3,b4 = bias.values()
    # input
    Z1 = np.dot(x,W1) + b1
    d1 = dropout_mask(Z1.shape,dropout_prob)
    A1 = softplus(Z1)
    A1 *= d1
    
    # hidden
    Z2 = np.dot(A1,W2) + b2
    d2 = dropout_mask(Z2.shape,dropout_prob)
    A2 = softplus(Z2)
    A2 *= d2
    Z3 = np.dot(A2,W3) + b3
    d3 = dropout_mask(Z3.shape,dropout_prob)
    A3 = softplus(Z3)
    A3 *= d3
    
    # output
    Z4 = np.dot(A3,W4) + b4
    A4 = softmax(Z4)
    
    return {'Z1': Z1,'Z2': Z2,'Z3': Z3,'Z4': Z4 },{'A1': A1, 'A2': A2,'A3': A3,'A4': A4}, {'d1':d1,'d2':d2,'d3':d3}

backward propagation

In [15]:
def backward_prop(x_batch, y_batch, outputs, weights, bias):
    m = y_batch.shape[0]
    Z1,Z2,Z3,Z4 = outputs[0].values()
    A1,A2,A3,A4 = outputs[1].values()
    d1,d2,d3 = outputs[2].values()
    W1, W2, W3, W4 = weights.values()
    b1,b2,b3,b4 = bias.values()
    
    dZ4 = A4 - one_hot(y_batch)
    dW4 = 1 / m * np.dot(dZ4.T,A3)
    db4 = 1 / m * np.sum(dZ4)
    
    dZ3 = np.dot(dZ4,W4.T) * softplus_deriv(Z3)
    dZ3 *= d3
    dW3 = 1 / m * np.dot(dZ3.T,A2)
    db3 = 1 / m * np.sum(dZ3)
    
    dZ2 = np.dot(dZ3,W3.T)* softplus_deriv(Z2)
    dZ2 *= d2
    dW2 = 1 / m * np.dot(dZ2.T,A1)
    db2 = 1 / m * np.sum(dZ2)
    
    dZ1 = np.dot(dZ2,W2.T) * softplus_deriv(Z1)
    dZ1 *= d1
    dW1 = 1 / m * np.dot(dZ1.T,x_batch)
    db1 = 1 / m * np.sum(dZ1)
    
    return {'dW4': dW4, 'db4': db4, 'dW3': dW3, 'db3': db3, 'dW2': dW2, 'db2': db2, 'dW1': dW1, 'db1': db1}

update the parameters (SGD)

In [16]:
def update_params(weights, bias, grads, lr):
    weights['W1'] -= lr * grads['dW1'].T
    weights['W2'] -= lr * grads['dW2'].T
    weights['W3'] -= lr * grads['dW3'].T
    weights['W4'] -= lr * grads['dW4'].T
    
    bias['b1'] -= lr * grads['db1']
    bias['b2'] -= lr * grads['db2']
    bias['b3'] -= lr * grads['db3']
    bias['b4'] -= lr * grads['db4']
       
    return weights, bias

MLP training

In [17]:
def train(x_train, y_train, x_test, y_test, lr, epochs, batch_size, width=200, dropout_prob=0):
    # initialization
    weights, bias = init_params(width=width)
    training_losses = []
    training_accuracies = []
    test_losses = []
    test_accuracies = []
    
    # sgd implementation
    for epoch in range(1, epochs+1):
        # shuffle
        indices = np.random.permutation(len(x_train))
        x_train = x_train[indices]
        y_train = y_train[indices]
        
        n_obs = x_train.shape[0]
        # a process bar to track the process
        progress_bar = tqdm(range(0, n_obs, batch_size), desc = f"Epoch {epoch}", unit = "batch", colour='green')
        count = 0

        # iterate through batches 
        for start in progress_bar:
            stop = start + batch_size
            x_batch, y_batch = x_train[start:stop], y_train[start:stop]
            
            forward_outputs = forward_prop(x_batch, weights, bias,dropout_prob)
            grads = backward_prop(x_batch, y_batch, forward_outputs, weights, bias)
            weights, bias = update_params(weights, bias, grads, lr)
        
            count += 1
            if count == len(progress_bar):
                # forward propogation
                forward_outputs_train = forward_prop(x_train, weights, bias)
                predictions_train = mlp_prediction(forward_outputs_train[1]['A4'])
                training_accuracy = mlp_accuracy(predictions_train, y_train)
                training_loss = np.mean(kl_loss(one_hot(y_train), one_hot(predictions_train)))
                
                # forward propagation on test set
                forward_outputs_test = forward_prop(x_test, weights, bias)
                predictions_test = mlp_prediction(forward_outputs_test[1]['A4'])
                test_accuracy =  mlp_accuracy(predictions_test, y_test)
                test_loss = np.mean(kl_loss(one_hot(y_test), one_hot(predictions_test)))
                
                # store on the lists
                training_losses.append(training_loss)
                training_accuracies.append(training_accuracy)
                test_losses.append(test_loss)
                test_accuracies.append(test_accuracy)
                
                progress_bar.set_postfix(train_loss = training_loss, train_accuracy = training_accuracy, test_loss = test_loss, test_accuracy = test_accuracy)

    return weights, bias, training_losses, training_accuracies,test_losses,test_accuracies

1.1.1 Train MLP, with batch size of 128 and 40 epochs¶

In [18]:
lr = 0.18
epochs = 40
batch_size = 128

w,b, train_loss, train_acc, test_loss,test_acc = train(x_train, y_train, x_test, y_test, lr, epochs, batch_size)
Epoch 1: 100%|█| 47/47 [00:01<00:00, 36.78batch/s, test_accuracy=0.242, test_loss=14, train_accuracy=0.228, train_loss=
Epoch 2: 100%|█| 47/47 [00:01<00:00, 38.31batch/s, test_accuracy=0.588, test_loss=7.59, train_accuracy=0.596, train_los
Epoch 3: 100%|█| 47/47 [00:01<00:00, 40.14batch/s, test_accuracy=0.677, test_loss=5.95, train_accuracy=0.679, train_los
Epoch 4: 100%|█| 47/47 [00:01<00:00, 44.26batch/s, test_accuracy=0.82, test_loss=3.31, train_accuracy=0.815, train_loss
Epoch 5: 100%|█| 47/47 [00:01<00:00, 36.73batch/s, test_accuracy=0.831, test_loss=3.11, train_accuracy=0.823, train_los
Epoch 6: 100%|█| 47/47 [00:01<00:00, 40.75batch/s, test_accuracy=0.87, test_loss=2.39, train_accuracy=0.86, train_loss=
Epoch 7: 100%|█| 47/47 [00:01<00:00, 39.83batch/s, test_accuracy=0.872, test_loss=2.36, train_accuracy=0.868, train_los
Epoch 8: 100%|█| 47/47 [00:00<00:00, 47.74batch/s, test_accuracy=0.878, test_loss=2.25, train_accuracy=0.881, train_los
Epoch 9: 100%|█| 47/47 [00:00<00:00, 50.40batch/s, test_accuracy=0.909, test_loss=1.68, train_accuracy=0.901, train_los
Epoch 10: 100%|█| 47/47 [00:00<00:00, 49.27batch/s, test_accuracy=0.923, test_loss=1.42, train_accuracy=0.911, train_lo
Epoch 11: 100%|█| 47/47 [00:00<00:00, 51.39batch/s, test_accuracy=0.915, test_loss=1.56, train_accuracy=0.91, train_los
Epoch 12: 100%|█| 47/47 [00:00<00:00, 50.68batch/s, test_accuracy=0.917, test_loss=1.53, train_accuracy=0.926, train_lo
Epoch 13: 100%|█| 47/47 [00:01<00:00, 40.70batch/s, test_accuracy=0.905, test_loss=1.75, train_accuracy=0.906, train_lo
Epoch 14: 100%|█| 47/47 [00:01<00:00, 38.44batch/s, test_accuracy=0.881, test_loss=2.19, train_accuracy=0.879, train_lo
Epoch 15: 100%|█| 47/47 [00:01<00:00, 43.99batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.922, train_lo
Epoch 16: 100%|█| 47/47 [00:01<00:00, 45.85batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.939, train_los
Epoch 17: 100%|█| 47/47 [00:00<00:00, 50.67batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.925, train_lo
Epoch 18: 100%|█| 47/47 [00:00<00:00, 52.62batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.943, train_lo
Epoch 19: 100%|█| 47/47 [00:00<00:00, 51.45batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.945, train_lo
Epoch 20: 100%|█| 47/47 [00:00<00:00, 50.95batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.945, train_lo
Epoch 21: 100%|█| 47/47 [00:00<00:00, 50.54batch/s, test_accuracy=0.905, test_loss=1.75, train_accuracy=0.913, train_lo
Epoch 22: 100%|█| 47/47 [00:00<00:00, 50.55batch/s, test_accuracy=0.908, test_loss=1.69, train_accuracy=0.933, train_lo
Epoch 23: 100%|█| 47/47 [00:01<00:00, 40.33batch/s, test_accuracy=0.921, test_loss=1.45, train_accuracy=0.941, train_lo
Epoch 24: 100%|█| 47/47 [00:01<00:00, 44.73batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.956, train_los
Epoch 25: 100%|█| 47/47 [00:01<00:00, 44.52batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.955, train_lo
Epoch 26: 100%|█| 47/47 [00:01<00:00, 45.24batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.96, train_los
Epoch 27: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.961, train_lo
Epoch 28: 100%|█| 47/47 [00:00<00:00, 48.78batch/s, test_accuracy=0.945, test_loss=1.01, train_accuracy=0.964, train_lo
Epoch 29: 100%|█| 47/47 [00:01<00:00, 43.35batch/s, test_accuracy=0.923, test_loss=1.42, train_accuracy=0.94, train_los
Epoch 30: 100%|█| 47/47 [00:00<00:00, 47.49batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.968, train_lo
Epoch 31: 100%|█| 47/47 [00:01<00:00, 44.75batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.962, train_los
Epoch 32: 100%|█| 47/47 [00:00<00:00, 47.62batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.915, train_loss
Epoch 33: 100%|█| 47/47 [00:01<00:00, 38.86batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.975, train_lo
Epoch 34: 100%|█| 47/47 [00:01<00:00, 45.31batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.969, train_los
Epoch 35: 100%|█| 47/47 [00:01<00:00, 42.03batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.976, train_lo
Epoch 36: 100%|█| 47/47 [00:01<00:00, 46.42batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.97, train_los
Epoch 37: 100%|█| 47/47 [00:01<00:00, 37.34batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.974, train_lo
Epoch 38: 100%|█| 47/47 [00:01<00:00, 44.50batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.973, train_lo
Epoch 39: 100%|█| 47/47 [00:01<00:00, 42.19batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.973, train_lo
Epoch 40: 100%|█| 47/47 [00:01<00:00, 45.57batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.988, train_lo
In [19]:
# print train_loss
train_loss
Out[19]:
[14.210015430141066,
 7.441003545258493,
 5.9037075550834395,
 3.4151904931433825,
 3.258699284562688,
 2.577502258976138,
 2.427147960535863,
 2.193945375199927,
 1.8287992218449738,
 1.629349642281344,
 1.6631026480536508,
 1.365462506243311,
 1.7275402045280541,
 2.2307668360424424,
 1.4329685177879243,
 1.119986100626536,
 1.3808047815943594,
 1.0555485441521326,
 1.0217955383798258,
 1.0064532630287777,
 1.6048020017196665,
 1.2243135730136654,
 1.0954384600648586,
 0.816209048675777,
 0.8254144138864059,
 0.7425661269907444,
 0.7180184864290668,
 0.6627862951652925,
 1.1107807354159072,
 0.5830064633398406,
 0.7026762110780186,
 1.5679805408771506,
 0.46640517067187254,
 0.5768695531994212,
 0.4357206199697756,
 0.5492534575675341,
 0.4848159010931306,
 0.500158176444179,
 0.49095281123355,
 0.23013413026572657]
In [20]:
# set different learning rates
learning_rates = np.logspace(-4, 0, 7)
epochs = 40
batch_size = 128

train_losses = []
test_losses = []
for learning_rate in learning_rates:
    print(f'Training MLP for learning rate = {learning_rate}')
    w,b, train_loss, train_acc, test_loss,test_acc = train(x_train, y_train, x_test, y_test, learning_rate, epochs,batch_size)
    # record the final loss
    train_losses.append(train_loss[-1])
    test_losses.append(test_loss[-1])
Training MLP for learning rate = 0.0001
Epoch 1: 100%|█| 47/47 [00:01<00:00, 42.42batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 2: 100%|█| 47/47 [00:01<00:00, 44.39batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 3: 100%|█| 47/47 [00:01<00:00, 34.45batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 4: 100%|█| 47/47 [00:03<00:00, 11.90batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 5: 100%|█| 47/47 [00:02<00:00, 20.31batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 6: 100%|█| 47/47 [00:01<00:00, 25.36batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 7: 100%|█| 47/47 [00:02<00:00, 20.49batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 8: 100%|█| 47/47 [00:02<00:00, 22.41batch/s, test_accuracy=0.098, test_loss=16.6, train_accuracy=0.0985, train_lo
Epoch 9: 100%|█| 47/47 [00:01<00:00, 30.30batch/s, test_accuracy=0.095, test_loss=16.7, train_accuracy=0.0968, train_lo
Epoch 10: 100%|█| 47/47 [00:04<00:00, 11.03batch/s, test_accuracy=0.081, test_loss=16.9, train_accuracy=0.089, train_lo
Epoch 11: 100%|█| 47/47 [00:02<00:00, 19.39batch/s, test_accuracy=0.072, test_loss=17.1, train_accuracy=0.08, train_los
Epoch 12: 100%|█| 47/47 [00:01<00:00, 24.93batch/s, test_accuracy=0.063, test_loss=17.3, train_accuracy=0.0705, train_l
Epoch 13: 100%|█| 47/47 [00:02<00:00, 19.62batch/s, test_accuracy=0.064, test_loss=17.2, train_accuracy=0.0697, train_l
Epoch 14: 100%|█| 47/47 [00:01<00:00, 26.84batch/s, test_accuracy=0.06, test_loss=17.3, train_accuracy=0.0688, train_lo
Epoch 15: 100%|█| 47/47 [00:01<00:00, 32.97batch/s, test_accuracy=0.063, test_loss=17.3, train_accuracy=0.0687, train_l
Epoch 16: 100%|█| 47/47 [00:01<00:00, 32.97batch/s, test_accuracy=0.064, test_loss=17.2, train_accuracy=0.071, train_lo
Epoch 17: 100%|█| 47/47 [00:01<00:00, 33.78batch/s, test_accuracy=0.07, test_loss=17.1, train_accuracy=0.0745, train_lo
Epoch 18: 100%|█| 47/47 [00:01<00:00, 34.64batch/s, test_accuracy=0.076, test_loss=17, train_accuracy=0.0775, train_los
Epoch 19: 100%|█| 47/47 [00:01<00:00, 33.60batch/s, test_accuracy=0.079, test_loss=17, train_accuracy=0.081, train_loss
Epoch 20: 100%|█| 47/47 [00:01<00:00, 33.93batch/s, test_accuracy=0.08, test_loss=16.9, train_accuracy=0.082, train_los
Epoch 21: 100%|█| 47/47 [00:01<00:00, 31.98batch/s, test_accuracy=0.085, test_loss=16.8, train_accuracy=0.0843, train_l
Epoch 22: 100%|█| 47/47 [00:01<00:00, 34.69batch/s, test_accuracy=0.091, test_loss=16.7, train_accuracy=0.0872, train_l
Epoch 23: 100%|█| 47/47 [00:01<00:00, 33.70batch/s, test_accuracy=0.094, test_loss=16.7, train_accuracy=0.0897, train_l
Epoch 24: 100%|█| 47/47 [00:01<00:00, 34.21batch/s, test_accuracy=0.097, test_loss=16.6, train_accuracy=0.0922, train_l
Epoch 25: 100%|█| 47/47 [00:01<00:00, 32.84batch/s, test_accuracy=0.099, test_loss=16.6, train_accuracy=0.0942, train_l
Epoch 26: 100%|█| 47/47 [00:01<00:00, 32.82batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.0962, train_los
Epoch 27: 100%|█| 47/47 [00:01<00:00, 33.49batch/s, test_accuracy=0.101, test_loss=16.6, train_accuracy=0.0973, train_l
Epoch 28: 100%|█| 47/47 [00:01<00:00, 32.70batch/s, test_accuracy=0.105, test_loss=16.5, train_accuracy=0.0982, train_l
Epoch 29: 100%|█| 47/47 [00:01<00:00, 28.34batch/s, test_accuracy=0.106, test_loss=16.5, train_accuracy=0.0998, train_l
Epoch 30: 100%|█| 47/47 [00:01<00:00, 27.95batch/s, test_accuracy=0.108, test_loss=16.4, train_accuracy=0.101, train_lo
Epoch 31: 100%|█| 47/47 [00:01<00:00, 33.67batch/s, test_accuracy=0.108, test_loss=16.4, train_accuracy=0.102, train_lo
Epoch 32: 100%|█| 47/47 [00:01<00:00, 36.49batch/s, test_accuracy=0.109, test_loss=16.4, train_accuracy=0.103, train_lo
Epoch 33: 100%|█| 47/47 [00:01<00:00, 33.64batch/s, test_accuracy=0.109, test_loss=16.4, train_accuracy=0.104, train_lo
Epoch 34: 100%|█| 47/47 [00:01<00:00, 30.15batch/s, test_accuracy=0.11, test_loss=16.4, train_accuracy=0.106, train_los
Epoch 35: 100%|█| 47/47 [00:01<00:00, 32.09batch/s, test_accuracy=0.112, test_loss=16.3, train_accuracy=0.105, train_lo
Epoch 36: 100%|█| 47/47 [00:01<00:00, 30.67batch/s, test_accuracy=0.111, test_loss=16.4, train_accuracy=0.105, train_lo
Epoch 37: 100%|█| 47/47 [00:01<00:00, 34.75batch/s, test_accuracy=0.11, test_loss=16.4, train_accuracy=0.106, train_los
Epoch 38: 100%|█| 47/47 [00:01<00:00, 33.78batch/s, test_accuracy=0.11, test_loss=16.4, train_accuracy=0.107, train_los
Epoch 39: 100%|█| 47/47 [00:01<00:00, 35.49batch/s, test_accuracy=0.11, test_loss=16.4, train_accuracy=0.107, train_los
Epoch 40: 100%|█| 47/47 [00:01<00:00, 34.25batch/s, test_accuracy=0.112, test_loss=16.3, train_accuracy=0.109, train_lo
Training MLP for learning rate = 0.00046415888336127773
Epoch 1: 100%|█| 47/47 [00:01<00:00, 35.19batch/s, test_accuracy=0.099, test_loss=16.6, train_accuracy=0.0983, train_lo
Epoch 2: 100%|█| 47/47 [00:01<00:00, 35.87batch/s, test_accuracy=0.094, test_loss=16.7, train_accuracy=0.0948, train_lo
Epoch 3: 100%|█| 47/47 [00:01<00:00, 35.33batch/s, test_accuracy=0.094, test_loss=16.7, train_accuracy=0.0918, train_lo
Epoch 4: 100%|█| 47/47 [00:01<00:00, 35.57batch/s, test_accuracy=0.106, test_loss=16.5, train_accuracy=0.105, train_los
Epoch 5: 100%|█| 47/47 [00:01<00:00, 35.37batch/s, test_accuracy=0.109, test_loss=16.4, train_accuracy=0.117, train_los
Epoch 6: 100%|█| 47/47 [00:01<00:00, 35.05batch/s, test_accuracy=0.123, test_loss=16.1, train_accuracy=0.122, train_los
Epoch 7: 100%|█| 47/47 [00:01<00:00, 37.60batch/s, test_accuracy=0.136, test_loss=15.9, train_accuracy=0.122, train_los
Epoch 8: 100%|█| 47/47 [00:01<00:00, 35.84batch/s, test_accuracy=0.136, test_loss=15.9, train_accuracy=0.132, train_los
Epoch 9: 100%|█| 47/47 [00:01<00:00, 33.78batch/s, test_accuracy=0.14, test_loss=15.8, train_accuracy=0.142, train_loss
Epoch 10: 100%|█| 47/47 [00:01<00:00, 35.64batch/s, test_accuracy=0.148, test_loss=15.7, train_accuracy=0.154, train_lo
Epoch 11: 100%|█| 47/47 [00:01<00:00, 35.24batch/s, test_accuracy=0.157, test_loss=15.5, train_accuracy=0.162, train_lo
Epoch 12: 100%|█| 47/47 [00:01<00:00, 35.00batch/s, test_accuracy=0.165, test_loss=15.4, train_accuracy=0.174, train_lo
Epoch 13: 100%|█| 47/47 [00:01<00:00, 36.69batch/s, test_accuracy=0.165, test_loss=15.4, train_accuracy=0.176, train_lo
Epoch 14: 100%|█| 47/47 [00:01<00:00, 36.05batch/s, test_accuracy=0.17, test_loss=15.3, train_accuracy=0.185, train_los
Epoch 15: 100%|█| 47/47 [00:01<00:00, 35.56batch/s, test_accuracy=0.172, test_loss=15.2, train_accuracy=0.192, train_lo
Epoch 16: 100%|█| 47/47 [00:01<00:00, 33.93batch/s, test_accuracy=0.183, test_loss=15, train_accuracy=0.199, train_loss
Epoch 17: 100%|█| 47/47 [00:01<00:00, 34.94batch/s, test_accuracy=0.186, test_loss=15, train_accuracy=0.204, train_loss
Epoch 18: 100%|█| 47/47 [00:01<00:00, 33.82batch/s, test_accuracy=0.198, test_loss=14.8, train_accuracy=0.213, train_lo
Epoch 19: 100%|█| 47/47 [00:01<00:00, 35.91batch/s, test_accuracy=0.203, test_loss=14.7, train_accuracy=0.219, train_lo
Epoch 20: 100%|█| 47/47 [00:01<00:00, 33.81batch/s, test_accuracy=0.207, test_loss=14.6, train_accuracy=0.227, train_lo
Epoch 21: 100%|█| 47/47 [00:01<00:00, 34.96batch/s, test_accuracy=0.22, test_loss=14.4, train_accuracy=0.233, train_los
Epoch 22: 100%|█| 47/47 [00:01<00:00, 36.36batch/s, test_accuracy=0.227, test_loss=14.2, train_accuracy=0.239, train_lo
Epoch 23: 100%|█| 47/47 [00:01<00:00, 44.30batch/s, test_accuracy=0.23, test_loss=14.2, train_accuracy=0.245, train_los
Epoch 24: 100%|█| 47/47 [00:01<00:00, 46.95batch/s, test_accuracy=0.24, test_loss=14, train_accuracy=0.253, train_loss=
Epoch 25: 100%|█| 47/47 [00:01<00:00, 43.90batch/s, test_accuracy=0.249, test_loss=13.8, train_accuracy=0.258, train_lo
Epoch 26: 100%|█| 47/47 [00:01<00:00, 45.71batch/s, test_accuracy=0.259, test_loss=13.6, train_accuracy=0.265, train_lo
Epoch 27: 100%|█| 47/47 [00:01<00:00, 46.80batch/s, test_accuracy=0.26, test_loss=13.6, train_accuracy=0.27, train_loss
Epoch 28: 100%|█| 47/47 [00:01<00:00, 45.00batch/s, test_accuracy=0.268, test_loss=13.5, train_accuracy=0.276, train_lo
Epoch 29: 100%|█| 47/47 [00:01<00:00, 44.27batch/s, test_accuracy=0.28, test_loss=13.3, train_accuracy=0.286, train_los
Epoch 30: 100%|█| 47/47 [00:01<00:00, 46.29batch/s, test_accuracy=0.288, test_loss=13.1, train_accuracy=0.298, train_lo
Epoch 31: 100%|█| 47/47 [00:00<00:00, 47.53batch/s, test_accuracy=0.291, test_loss=13.1, train_accuracy=0.302, train_lo
Epoch 32: 100%|█| 47/47 [00:00<00:00, 47.68batch/s, test_accuracy=0.303, test_loss=12.8, train_accuracy=0.316, train_lo
Epoch 33: 100%|█| 47/47 [00:01<00:00, 45.18batch/s, test_accuracy=0.302, test_loss=12.9, train_accuracy=0.318, train_lo
Epoch 34: 100%|█| 47/47 [00:00<00:00, 48.68batch/s, test_accuracy=0.313, test_loss=12.6, train_accuracy=0.325, train_lo
Epoch 35: 100%|█| 47/47 [00:00<00:00, 48.78batch/s, test_accuracy=0.328, test_loss=12.4, train_accuracy=0.333, train_lo
Epoch 36: 100%|█| 47/47 [00:01<00:00, 43.71batch/s, test_accuracy=0.326, test_loss=12.4, train_accuracy=0.342, train_lo
Epoch 37: 100%|█| 47/47 [00:01<00:00, 43.30batch/s, test_accuracy=0.337, test_loss=12.2, train_accuracy=0.346, train_lo
Epoch 38: 100%|█| 47/47 [00:01<00:00, 47.00batch/s, test_accuracy=0.343, test_loss=12.1, train_accuracy=0.354, train_lo
Epoch 39: 100%|█| 47/47 [00:01<00:00, 46.10batch/s, test_accuracy=0.348, test_loss=12, train_accuracy=0.364, train_loss
Epoch 40: 100%|█| 47/47 [00:01<00:00, 45.93batch/s, test_accuracy=0.353, test_loss=11.9, train_accuracy=0.367, train_lo
Training MLP for learning rate = 0.002154434690031882
Epoch 1: 100%|█| 47/47 [00:00<00:00, 47.78batch/s, test_accuracy=0.079, test_loss=17, train_accuracy=0.0878, train_loss
Epoch 2: 100%|█| 47/47 [00:00<00:00, 48.03batch/s, test_accuracy=0.134, test_loss=15.9, train_accuracy=0.125, train_los
Epoch 3: 100%|█| 47/47 [00:01<00:00, 46.49batch/s, test_accuracy=0.157, test_loss=15.5, train_accuracy=0.148, train_los
Epoch 4: 100%|█| 47/47 [00:00<00:00, 48.24batch/s, test_accuracy=0.229, test_loss=14.2, train_accuracy=0.223, train_los
Epoch 5: 100%|█| 47/47 [00:01<00:00, 45.99batch/s, test_accuracy=0.259, test_loss=13.6, train_accuracy=0.254, train_los
Epoch 6: 100%|█| 47/47 [00:01<00:00, 42.33batch/s, test_accuracy=0.285, test_loss=13.2, train_accuracy=0.287, train_los
Epoch 7: 100%|█| 47/47 [00:02<00:00, 16.35batch/s, test_accuracy=0.337, test_loss=12.2, train_accuracy=0.33, train_loss
Epoch 8: 100%|█| 47/47 [00:02<00:00, 21.64batch/s, test_accuracy=0.337, test_loss=12.2, train_accuracy=0.332, train_los
Epoch 9: 100%|█| 47/47 [00:01<00:00, 29.36batch/s, test_accuracy=0.358, test_loss=11.8, train_accuracy=0.34, train_loss
Epoch 10: 100%|█| 47/47 [00:01<00:00, 24.24batch/s, test_accuracy=0.405, test_loss=11, train_accuracy=0.409, train_loss
Epoch 11: 100%|█| 47/47 [00:01<00:00, 28.28batch/s, test_accuracy=0.427, test_loss=10.5, train_accuracy=0.403, train_lo
Epoch 12: 100%|█| 47/47 [00:01<00:00, 30.93batch/s, test_accuracy=0.478, test_loss=9.61, train_accuracy=0.448, train_lo
Epoch 13: 100%|█| 47/47 [00:01<00:00, 36.17batch/s, test_accuracy=0.54, test_loss=8.47, train_accuracy=0.516, train_los
Epoch 14: 100%|█| 47/47 [00:01<00:00, 35.90batch/s, test_accuracy=0.52, test_loss=8.84, train_accuracy=0.503, train_los
Epoch 15: 100%|█| 47/47 [00:01<00:00, 33.76batch/s, test_accuracy=0.502, test_loss=9.17, train_accuracy=0.491, train_lo
Epoch 16: 100%|█| 47/47 [00:01<00:00, 35.51batch/s, test_accuracy=0.571, test_loss=7.9, train_accuracy=0.558, train_los
Epoch 17: 100%|█| 47/47 [00:01<00:00, 35.52batch/s, test_accuracy=0.552, test_loss=8.25, train_accuracy=0.542, train_lo
Epoch 18: 100%|█| 47/47 [00:01<00:00, 34.92batch/s, test_accuracy=0.574, test_loss=7.84, train_accuracy=0.555, train_lo
Epoch 19: 100%|█| 47/47 [00:01<00:00, 32.98batch/s, test_accuracy=0.582, test_loss=7.7, train_accuracy=0.557, train_los
Epoch 20: 100%|█| 47/47 [00:01<00:00, 34.64batch/s, test_accuracy=0.615, test_loss=7.09, train_accuracy=0.587, train_lo
Epoch 21: 100%|█| 47/47 [00:01<00:00, 36.74batch/s, test_accuracy=0.619, test_loss=7.01, train_accuracy=0.587, train_lo
Epoch 22: 100%|█| 47/47 [00:01<00:00, 36.93batch/s, test_accuracy=0.607, test_loss=7.24, train_accuracy=0.592, train_lo
Epoch 23: 100%|█| 47/47 [00:01<00:00, 35.18batch/s, test_accuracy=0.664, test_loss=6.19, train_accuracy=0.625, train_lo
Epoch 24: 100%|█| 47/47 [00:01<00:00, 36.72batch/s, test_accuracy=0.619, test_loss=7.01, train_accuracy=0.593, train_lo
Epoch 25: 100%|█| 47/47 [00:01<00:00, 35.51batch/s, test_accuracy=0.682, test_loss=5.85, train_accuracy=0.664, train_lo
Epoch 26: 100%|█| 47/47 [00:01<00:00, 35.74batch/s, test_accuracy=0.679, test_loss=5.91, train_accuracy=0.645, train_lo
Epoch 27: 100%|█| 47/47 [00:01<00:00, 36.62batch/s, test_accuracy=0.668, test_loss=6.11, train_accuracy=0.639, train_lo
Epoch 28: 100%|█| 47/47 [00:01<00:00, 33.77batch/s, test_accuracy=0.632, test_loss=6.78, train_accuracy=0.607, train_lo
Epoch 29: 100%|█| 47/47 [00:01<00:00, 35.56batch/s, test_accuracy=0.689, test_loss=5.73, train_accuracy=0.667, train_lo
Epoch 30: 100%|█| 47/47 [00:01<00:00, 35.01batch/s, test_accuracy=0.703, test_loss=5.47, train_accuracy=0.676, train_lo
Epoch 31: 100%|█| 47/47 [00:01<00:00, 36.62batch/s, test_accuracy=0.687, test_loss=5.76, train_accuracy=0.662, train_lo
Epoch 32: 100%|█| 47/47 [00:01<00:00, 34.31batch/s, test_accuracy=0.721, test_loss=5.14, train_accuracy=0.705, train_lo
Epoch 33: 100%|█| 47/47 [00:01<00:00, 35.22batch/s, test_accuracy=0.72, test_loss=5.16, train_accuracy=0.689, train_los
Epoch 34: 100%|█| 47/47 [00:01<00:00, 35.39batch/s, test_accuracy=0.727, test_loss=5.03, train_accuracy=0.695, train_lo
Epoch 35: 100%|█| 47/47 [00:01<00:00, 32.02batch/s, test_accuracy=0.724, test_loss=5.08, train_accuracy=0.7, train_loss
Epoch 36: 100%|█| 47/47 [00:01<00:00, 36.53batch/s, test_accuracy=0.721, test_loss=5.14, train_accuracy=0.701, train_lo
Epoch 37: 100%|█| 47/47 [00:00<00:00, 47.51batch/s, test_accuracy=0.771, test_loss=4.22, train_accuracy=0.744, train_lo
Epoch 38: 100%|█| 47/47 [00:01<00:00, 40.25batch/s, test_accuracy=0.749, test_loss=4.62, train_accuracy=0.727, train_lo
Epoch 39: 100%|█| 47/47 [00:01<00:00, 39.19batch/s, test_accuracy=0.754, test_loss=4.53, train_accuracy=0.724, train_lo
Epoch 40: 100%|█| 47/47 [00:01<00:00, 43.86batch/s, test_accuracy=0.773, test_loss=4.18, train_accuracy=0.75, train_los
Training MLP for learning rate = 0.01
Epoch 1: 100%|█| 47/47 [00:00<00:00, 50.50batch/s, test_accuracy=0.194, test_loss=14.8, train_accuracy=0.193, train_los
Epoch 2: 100%|█| 47/47 [00:00<00:00, 48.89batch/s, test_accuracy=0.179, test_loss=15.1, train_accuracy=0.186, train_los
Epoch 3: 100%|█| 47/47 [00:01<00:00, 38.61batch/s, test_accuracy=0.433, test_loss=10.4, train_accuracy=0.428, train_los
Epoch 4: 100%|█| 47/47 [00:01<00:00, 45.37batch/s, test_accuracy=0.474, test_loss=9.68, train_accuracy=0.485, train_los
Epoch 5: 100%|█| 47/47 [00:01<00:00, 45.85batch/s, test_accuracy=0.546, test_loss=8.36, train_accuracy=0.536, train_los
Epoch 6: 100%|█| 47/47 [00:00<00:00, 48.28batch/s, test_accuracy=0.562, test_loss=8.06, train_accuracy=0.568, train_los
Epoch 7: 100%|█| 47/47 [00:00<00:00, 49.49batch/s, test_accuracy=0.559, test_loss=8.12, train_accuracy=0.538, train_los
Epoch 8: 100%|█| 47/47 [00:00<00:00, 47.27batch/s, test_accuracy=0.631, test_loss=6.79, train_accuracy=0.636, train_los
Epoch 9: 100%|█| 47/47 [00:00<00:00, 48.24batch/s, test_accuracy=0.687, test_loss=5.76, train_accuracy=0.681, train_los
Epoch 10: 100%|█| 47/47 [00:00<00:00, 48.81batch/s, test_accuracy=0.75, test_loss=4.6, train_accuracy=0.746, train_loss
Epoch 11: 100%|█| 47/47 [00:00<00:00, 50.21batch/s, test_accuracy=0.738, test_loss=4.82, train_accuracy=0.725, train_lo
Epoch 12: 100%|█| 47/47 [00:00<00:00, 48.94batch/s, test_accuracy=0.733, test_loss=4.92, train_accuracy=0.733, train_lo
Epoch 13: 100%|█| 47/47 [00:00<00:00, 50.79batch/s, test_accuracy=0.748, test_loss=4.64, train_accuracy=0.748, train_lo
Epoch 14: 100%|█| 47/47 [00:00<00:00, 49.54batch/s, test_accuracy=0.802, test_loss=3.65, train_accuracy=0.79, train_los
Epoch 15: 100%|█| 47/47 [00:00<00:00, 48.62batch/s, test_accuracy=0.784, test_loss=3.98, train_accuracy=0.785, train_lo
Epoch 16: 100%|█| 47/47 [00:00<00:00, 48.33batch/s, test_accuracy=0.804, test_loss=3.61, train_accuracy=0.786, train_lo
Epoch 17: 100%|█| 47/47 [00:00<00:00, 49.12batch/s, test_accuracy=0.825, test_loss=3.22, train_accuracy=0.823, train_lo
Epoch 18: 100%|█| 47/47 [00:00<00:00, 50.56batch/s, test_accuracy=0.788, test_loss=3.9, train_accuracy=0.791, train_los
Epoch 19: 100%|█| 47/47 [00:00<00:00, 49.71batch/s, test_accuracy=0.857, test_loss=2.63, train_accuracy=0.838, train_lo
Epoch 20: 100%|█| 47/47 [00:00<00:00, 49.50batch/s, test_accuracy=0.866, test_loss=2.47, train_accuracy=0.852, train_lo
Epoch 21: 100%|█| 47/47 [00:00<00:00, 47.82batch/s, test_accuracy=0.809, test_loss=3.52, train_accuracy=0.805, train_lo
Epoch 22: 100%|█| 47/47 [00:01<00:00, 44.21batch/s, test_accuracy=0.783, test_loss=4, train_accuracy=0.782, train_loss=
Epoch 23: 100%|█| 47/47 [00:00<00:00, 48.62batch/s, test_accuracy=0.864, test_loss=2.5, train_accuracy=0.858, train_los
Epoch 24: 100%|█| 47/47 [00:01<00:00, 46.89batch/s, test_accuracy=0.864, test_loss=2.5, train_accuracy=0.852, train_los
Epoch 25: 100%|█| 47/47 [00:00<00:00, 48.58batch/s, test_accuracy=0.872, test_loss=2.36, train_accuracy=0.867, train_lo
Epoch 26: 100%|█| 47/47 [00:00<00:00, 48.40batch/s, test_accuracy=0.882, test_loss=2.17, train_accuracy=0.871, train_lo
Epoch 27: 100%|█| 47/47 [00:00<00:00, 50.29batch/s, test_accuracy=0.865, test_loss=2.49, train_accuracy=0.86, train_los
Epoch 28: 100%|█| 47/47 [00:00<00:00, 48.77batch/s, test_accuracy=0.876, test_loss=2.28, train_accuracy=0.868, train_lo
Epoch 29: 100%|█| 47/47 [00:00<00:00, 48.04batch/s, test_accuracy=0.86, test_loss=2.58, train_accuracy=0.861, train_los
Epoch 30: 100%|█| 47/47 [00:00<00:00, 48.77batch/s, test_accuracy=0.887, test_loss=2.08, train_accuracy=0.884, train_lo
Epoch 31: 100%|█| 47/47 [00:00<00:00, 49.22batch/s, test_accuracy=0.889, test_loss=2.04, train_accuracy=0.883, train_lo
Epoch 32: 100%|█| 47/47 [00:01<00:00, 45.53batch/s, test_accuracy=0.883, test_loss=2.15, train_accuracy=0.862, train_lo
Epoch 33: 100%|█| 47/47 [00:00<00:00, 48.42batch/s, test_accuracy=0.897, test_loss=1.9, train_accuracy=0.886, train_los
Epoch 34: 100%|█| 47/47 [00:00<00:00, 51.62batch/s, test_accuracy=0.898, test_loss=1.88, train_accuracy=0.884, train_lo
Epoch 35: 100%|█| 47/47 [00:00<00:00, 50.73batch/s, test_accuracy=0.898, test_loss=1.88, train_accuracy=0.89, train_los
Epoch 36: 100%|█| 47/47 [00:00<00:00, 47.89batch/s, test_accuracy=0.897, test_loss=1.9, train_accuracy=0.889, train_los
Epoch 37: 100%|█| 47/47 [00:00<00:00, 48.14batch/s, test_accuracy=0.895, test_loss=1.93, train_accuracy=0.882, train_lo
Epoch 38: 100%|█| 47/47 [00:00<00:00, 48.27batch/s, test_accuracy=0.887, test_loss=2.08, train_accuracy=0.88, train_los
Epoch 39: 100%|█| 47/47 [00:01<00:00, 46.83batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.895, train_lo
Epoch 40: 100%|█| 47/47 [00:00<00:00, 48.33batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.896, train_lo
Training MLP for learning rate = 0.046415888336127774
Epoch 1: 100%|█| 47/47 [00:00<00:00, 50.65batch/s, test_accuracy=0.152, test_loss=15.6, train_accuracy=0.151, train_los
Epoch 2: 100%|█| 47/47 [00:00<00:00, 49.92batch/s, test_accuracy=0.337, test_loss=12.2, train_accuracy=0.33, train_loss
Epoch 3: 100%|█| 47/47 [00:00<00:00, 51.43batch/s, test_accuracy=0.618, test_loss=7.03, train_accuracy=0.62, train_loss
Epoch 4: 100%|█| 47/47 [00:00<00:00, 49.22batch/s, test_accuracy=0.578, test_loss=7.77, train_accuracy=0.578, train_los
Epoch 5: 100%|█| 47/47 [00:00<00:00, 50.35batch/s, test_accuracy=0.796, test_loss=3.76, train_accuracy=0.789, train_los
Epoch 6: 100%|█| 47/47 [00:00<00:00, 49.76batch/s, test_accuracy=0.711, test_loss=5.32, train_accuracy=0.71, train_loss
Epoch 7: 100%|█| 47/47 [00:00<00:00, 51.60batch/s, test_accuracy=0.791, test_loss=3.85, train_accuracy=0.8, train_loss=
Epoch 8: 100%|█| 47/47 [00:00<00:00, 50.46batch/s, test_accuracy=0.864, test_loss=2.5, train_accuracy=0.864, train_loss
Epoch 9: 100%|█| 47/47 [00:00<00:00, 51.83batch/s, test_accuracy=0.881, test_loss=2.19, train_accuracy=0.877, train_los
Epoch 10: 100%|█| 47/47 [00:00<00:00, 48.38batch/s, test_accuracy=0.856, test_loss=2.65, train_accuracy=0.858, train_lo
Epoch 11: 100%|█| 47/47 [00:01<00:00, 39.87batch/s, test_accuracy=0.868, test_loss=2.43, train_accuracy=0.871, train_lo
Epoch 12: 100%|█| 47/47 [00:01<00:00, 46.71batch/s, test_accuracy=0.886, test_loss=2.1, train_accuracy=0.875, train_los
Epoch 13: 100%|█| 47/47 [00:00<00:00, 47.94batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.903, train_lo
Epoch 14: 100%|█| 47/47 [00:00<00:00, 47.69batch/s, test_accuracy=0.877, test_loss=2.26, train_accuracy=0.876, train_lo
Epoch 15: 100%|█| 47/47 [00:00<00:00, 48.98batch/s, test_accuracy=0.893, test_loss=1.97, train_accuracy=0.887, train_lo
Epoch 16: 100%|█| 47/47 [00:00<00:00, 48.53batch/s, test_accuracy=0.907, test_loss=1.71, train_accuracy=0.906, train_lo
Epoch 17: 100%|█| 47/47 [00:00<00:00, 48.01batch/s, test_accuracy=0.882, test_loss=2.17, train_accuracy=0.871, train_lo
Epoch 18: 100%|█| 47/47 [00:01<00:00, 44.37batch/s, test_accuracy=0.903, test_loss=1.79, train_accuracy=0.898, train_lo
Epoch 19: 100%|█| 47/47 [00:00<00:00, 48.46batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.903, train_loss
Epoch 20: 100%|█| 47/47 [00:01<00:00, 44.71batch/s, test_accuracy=0.911, test_loss=1.64, train_accuracy=0.909, train_lo
Epoch 21: 100%|█| 47/47 [00:00<00:00, 47.01batch/s, test_accuracy=0.856, test_loss=2.65, train_accuracy=0.865, train_lo
Epoch 22: 100%|█| 47/47 [00:00<00:00, 49.48batch/s, test_accuracy=0.91, test_loss=1.66, train_accuracy=0.917, train_los
Epoch 23: 100%|█| 47/47 [00:00<00:00, 47.92batch/s, test_accuracy=0.902, test_loss=1.8, train_accuracy=0.904, train_los
Epoch 24: 100%|█| 47/47 [00:00<00:00, 50.19batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.916, train_lo
Epoch 25: 100%|█| 47/47 [00:00<00:00, 49.77batch/s, test_accuracy=0.906, test_loss=1.73, train_accuracy=0.908, train_lo
Epoch 26: 100%|█| 47/47 [00:00<00:00, 48.60batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.919, train_lo
Epoch 27: 100%|█| 47/47 [00:00<00:00, 49.87batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.905, train_loss
Epoch 28: 100%|█| 47/47 [00:00<00:00, 49.64batch/s, test_accuracy=0.915, test_loss=1.56, train_accuracy=0.917, train_lo
Epoch 29: 100%|█| 47/47 [00:01<00:00, 45.16batch/s, test_accuracy=0.912, test_loss=1.62, train_accuracy=0.922, train_lo
Epoch 30: 100%|█| 47/47 [00:00<00:00, 50.30batch/s, test_accuracy=0.907, test_loss=1.71, train_accuracy=0.919, train_lo
Epoch 31: 100%|█| 47/47 [00:01<00:00, 43.81batch/s, test_accuracy=0.926, test_loss=1.36, train_accuracy=0.932, train_lo
Epoch 32: 100%|█| 47/47 [00:02<00:00, 19.07batch/s, test_accuracy=0.904, test_loss=1.77, train_accuracy=0.91, train_los
Epoch 33: 100%|█| 47/47 [00:03<00:00, 13.59batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.928, train_lo
Epoch 34: 100%|█| 47/47 [00:01<00:00, 25.69batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.928, train_lo
Epoch 35: 100%|█| 47/47 [00:01<00:00, 27.09batch/s, test_accuracy=0.917, test_loss=1.53, train_accuracy=0.933, train_lo
Epoch 36: 100%|█| 47/47 [00:01<00:00, 28.26batch/s, test_accuracy=0.921, test_loss=1.45, train_accuracy=0.929, train_lo
Epoch 37: 100%|█| 47/47 [00:01<00:00, 37.08batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.937, train_lo
Epoch 38: 100%|█| 47/47 [00:01<00:00, 35.14batch/s, test_accuracy=0.922, test_loss=1.44, train_accuracy=0.933, train_lo
Epoch 39: 100%|█| 47/47 [00:01<00:00, 33.26batch/s, test_accuracy=0.928, test_loss=1.33, train_accuracy=0.939, train_lo
Epoch 40: 100%|█| 47/47 [00:01<00:00, 26.46batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.942, train_lo
Training MLP for learning rate = 0.21544346900318823
Epoch 1: 100%|█| 47/47 [00:02<00:00, 21.76batch/s, test_accuracy=0.195, test_loss=14.8, train_accuracy=0.192, train_los
Epoch 2: 100%|█| 47/47 [00:01<00:00, 27.31batch/s, test_accuracy=0.415, test_loss=10.8, train_accuracy=0.406, train_los
Epoch 3: 100%|█| 47/47 [00:01<00:00, 30.72batch/s, test_accuracy=0.659, test_loss=6.28, train_accuracy=0.652, train_los
Epoch 4: 100%|█| 47/47 [00:01<00:00, 31.26batch/s, test_accuracy=0.759, test_loss=4.44, train_accuracy=0.765, train_los
Epoch 5: 100%|█| 47/47 [00:01<00:00, 33.99batch/s, test_accuracy=0.813, test_loss=3.44, train_accuracy=0.806, train_los
Epoch 6: 100%|█| 47/47 [00:01<00:00, 34.58batch/s, test_accuracy=0.855, test_loss=2.67, train_accuracy=0.864, train_los
Epoch 7: 100%|█| 47/47 [00:01<00:00, 34.80batch/s, test_accuracy=0.876, test_loss=2.28, train_accuracy=0.881, train_los
Epoch 8: 100%|█| 47/47 [00:01<00:00, 34.06batch/s, test_accuracy=0.855, test_loss=2.67, train_accuracy=0.853, train_los
Epoch 9: 100%|█| 47/47 [00:01<00:00, 29.78batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.898, train_loss=
Epoch 10: 100%|█| 47/47 [00:01<00:00, 32.49batch/s, test_accuracy=0.89, test_loss=2.03, train_accuracy=0.899, train_los
Epoch 11: 100%|█| 47/47 [00:01<00:00, 32.83batch/s, test_accuracy=0.9, test_loss=1.84, train_accuracy=0.894, train_loss
Epoch 12: 100%|█| 47/47 [00:01<00:00, 35.07batch/s, test_accuracy=0.874, test_loss=2.32, train_accuracy=0.881, train_lo
Epoch 13: 100%|█| 47/47 [00:01<00:00, 33.77batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.927, train_lo
Epoch 14: 100%|█| 47/47 [00:01<00:00, 35.80batch/s, test_accuracy=0.93, test_loss=1.29, train_accuracy=0.934, train_los
Epoch 15: 100%|█| 47/47 [00:01<00:00, 35.50batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.928, train_lo
Epoch 16: 100%|█| 47/47 [00:01<00:00, 36.19batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.931, train_lo
Epoch 17: 100%|█| 47/47 [00:01<00:00, 34.07batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.944, train_lo
Epoch 18: 100%|█| 47/47 [00:01<00:00, 35.74batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.945, train_lo
Epoch 19: 100%|█| 47/47 [00:01<00:00, 35.86batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.947, train_lo
Epoch 20: 100%|█| 47/47 [00:01<00:00, 35.22batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.947, train_lo
Epoch 21: 100%|█| 47/47 [00:01<00:00, 33.57batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.955, train_los
Epoch 22: 100%|█| 47/47 [00:01<00:00, 38.66batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.948, train_lo
Epoch 23: 100%|█| 47/47 [00:01<00:00, 45.04batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.954, train_lo
Epoch 24: 100%|█| 47/47 [00:01<00:00, 46.61batch/s, test_accuracy=0.92, test_loss=1.47, train_accuracy=0.936, train_los
Epoch 25: 100%|█| 47/47 [00:01<00:00, 39.51batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.937, train_lo
Epoch 26: 100%|█| 47/47 [00:01<00:00, 45.32batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.948, train_los
Epoch 27: 100%|█| 47/47 [00:01<00:00, 45.22batch/s, test_accuracy=0.946, test_loss=0.994, train_accuracy=0.964, train_l
Epoch 28: 100%|█| 47/47 [00:01<00:00, 44.73batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.959, train_lo
Epoch 29: 100%|█| 47/47 [00:01<00:00, 45.66batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.96, train_los
Epoch 30: 100%|█| 47/47 [00:01<00:00, 42.64batch/s, test_accuracy=0.949, test_loss=0.939, train_accuracy=0.977, train_l
Epoch 31: 100%|█| 47/47 [00:01<00:00, 44.88batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.95, train_los
Epoch 32: 100%|█| 47/47 [00:01<00:00, 45.53batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.971, train_lo
Epoch 33: 100%|█| 47/47 [00:01<00:00, 45.36batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.973, train_lo
Epoch 34: 100%|█| 47/47 [00:01<00:00, 44.21batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.973, train_lo
Epoch 35: 100%|█| 47/47 [00:01<00:00, 42.73batch/s, test_accuracy=0.944, test_loss=1.03, train_accuracy=0.981, train_lo
Epoch 36: 100%|█| 47/47 [00:01<00:00, 43.72batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.975, train_lo
Epoch 37: 100%|█| 47/47 [00:01<00:00, 41.21batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.978, train_lo
Epoch 38: 100%|█| 47/47 [00:01<00:00, 44.30batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.968, train_lo
Epoch 39: 100%|█| 47/47 [00:01<00:00, 44.29batch/s, test_accuracy=0.58, test_loss=7.73, train_accuracy=0.596, train_los
Epoch 40: 100%|█| 47/47 [00:01<00:00, 46.87batch/s, test_accuracy=0.94, test_loss=1.1, train_accuracy=0.979, train_loss
Training MLP for learning rate = 1.0
Epoch 1: 100%|█| 47/47 [00:00<00:00, 48.34batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 2: 100%|█| 47/47 [00:00<00:00, 48.87batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 3: 100%|█| 47/47 [00:01<00:00, 44.84batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 4: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 5: 100%|█| 47/47 [00:00<00:00, 49.57batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 6: 100%|█| 47/47 [00:00<00:00, 48.61batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 7: 100%|█| 47/47 [00:00<00:00, 47.60batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 8: 100%|█| 47/47 [00:01<00:00, 46.72batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 9: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=16
Epoch 10: 100%|█| 47/47 [00:01<00:00, 45.53batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 11: 100%|█| 47/47 [00:00<00:00, 47.22batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 12: 100%|█| 47/47 [00:00<00:00, 48.88batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 13: 100%|█| 47/47 [00:00<00:00, 47.05batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 14: 100%|█| 47/47 [00:00<00:00, 47.90batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 15: 100%|█| 47/47 [00:00<00:00, 47.84batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 16: 100%|█| 47/47 [00:01<00:00, 46.27batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 17: 100%|█| 47/47 [00:00<00:00, 47.17batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 18: 100%|█| 47/47 [00:00<00:00, 48.37batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 19: 100%|█| 47/47 [00:01<00:00, 46.61batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 20: 100%|█| 47/47 [00:01<00:00, 44.13batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 21: 100%|█| 47/47 [00:00<00:00, 49.69batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 22: 100%|█| 47/47 [00:00<00:00, 48.09batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 23: 100%|█| 47/47 [00:01<00:00, 43.99batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 24: 100%|█| 47/47 [00:01<00:00, 43.79batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 25: 100%|█| 47/47 [00:01<00:00, 46.11batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 26: 100%|█| 47/47 [00:01<00:00, 46.49batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 27: 100%|█| 47/47 [00:00<00:00, 48.56batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 28: 100%|█| 47/47 [00:01<00:00, 46.84batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 29: 100%|█| 47/47 [00:00<00:00, 48.07batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 30: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 31: 100%|█| 47/47 [00:01<00:00, 44.04batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 32: 100%|█| 47/47 [00:00<00:00, 47.35batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 33: 100%|█| 47/47 [00:01<00:00, 42.90batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 34: 100%|█| 47/47 [00:01<00:00, 45.96batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 35: 100%|█| 47/47 [00:00<00:00, 47.03batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 36: 100%|█| 47/47 [00:01<00:00, 42.50batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 37: 100%|█| 47/47 [00:00<00:00, 47.70batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 38: 100%|█| 47/47 [00:01<00:00, 41.04batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 39: 100%|█| 47/47 [00:01<00:00, 45.27batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
Epoch 40: 100%|█| 47/47 [00:01<00:00, 43.53batch/s, test_accuracy=0.1, test_loss=16.6, train_accuracy=0.1, train_loss=1
In [21]:
# plotting
plt.plot(learning_rates, train_losses,'.-',label='Train loss')
plt.plot(learning_rates, test_losses,'.-',label='Test loss')

plt.xscale('log')
plt.xlabel('Learning rate')
plt.ylabel('Final loss')
plt.title('Final Losses against Learning rates')
plt.legend()
plt.show()

Explanation:¶

In finding the relations between the losses and the learning rate, options for learning rates from $10^{-4}$ to $10^{0}$ are split into 7 points at equal distance.

From the plot, the final loss at epoch=40 for both train and test data have a decreasing trend, meaning that a too small learning rate ($lr$) is getting close but not enough to converge to the optimal point.

However, the loss explodes in the final bit, meaning that a big step makes the final loss sway unstablely around the optimal point and hardly converges to the optimal point.

Optimal learning rate should be set at a value where both the train loss is minimal in the restricted interval and right before the test loss is exploding.

Note that this doesn't imply any information about overfitting or underfitting.

1.1.2 Train the MLP at optimal learning rate found in 1.1.1. Use loss and accuracy to be the model measure.¶

In [22]:
best_lr = learning_rates[-2]
print("Best learning rate:", best_lr)

epochs = 40
batch_size = 128

best_w,best_b, best_train_loss, best_train_acc, best_test_loss, best_test_acc = train(x_train, y_train, x_test, y_test, best_lr, epochs,batch_size)
Best learning rate: 0.21544346900318823
Epoch 1: 100%|█| 47/47 [00:01<00:00, 46.95batch/s, test_accuracy=0.121, test_loss=16.2, train_accuracy=0.13, train_loss
Epoch 2: 100%|█| 47/47 [00:01<00:00, 43.55batch/s, test_accuracy=0.359, test_loss=11.8, train_accuracy=0.356, train_los
Epoch 3: 100%|█| 47/47 [00:01<00:00, 45.47batch/s, test_accuracy=0.545, test_loss=8.38, train_accuracy=0.536, train_los
Epoch 4: 100%|█| 47/47 [00:01<00:00, 46.50batch/s, test_accuracy=0.602, test_loss=7.33, train_accuracy=0.584, train_los
Epoch 5: 100%|█| 47/47 [00:01<00:00, 44.83batch/s, test_accuracy=0.719, test_loss=5.17, train_accuracy=0.723, train_los
Epoch 6: 100%|█| 47/47 [00:01<00:00, 46.84batch/s, test_accuracy=0.783, test_loss=4, train_accuracy=0.777, train_loss=4
Epoch 7: 100%|█| 47/47 [00:01<00:00, 46.47batch/s, test_accuracy=0.854, test_loss=2.69, train_accuracy=0.859, train_los
Epoch 8: 100%|█| 47/47 [00:00<00:00, 47.17batch/s, test_accuracy=0.858, test_loss=2.61, train_accuracy=0.865, train_los
Epoch 9: 100%|█| 47/47 [00:01<00:00, 30.05batch/s, test_accuracy=0.887, test_loss=2.08, train_accuracy=0.893, train_los
Epoch 10: 100%|█| 47/47 [00:03<00:00, 14.71batch/s, test_accuracy=0.888, test_loss=2.06, train_accuracy=0.888, train_lo
Epoch 11: 100%|█| 47/47 [00:01<00:00, 25.39batch/s, test_accuracy=0.89, test_loss=2.03, train_accuracy=0.894, train_los
Epoch 12: 100%|█| 47/47 [00:01<00:00, 27.07batch/s, test_accuracy=0.903, test_loss=1.79, train_accuracy=0.905, train_lo
Epoch 13: 100%|█| 47/47 [00:01<00:00, 25.46batch/s, test_accuracy=0.902, test_loss=1.8, train_accuracy=0.911, train_los
Epoch 14: 100%|█| 47/47 [00:01<00:00, 27.23batch/s, test_accuracy=0.903, test_loss=1.79, train_accuracy=0.916, train_lo
Epoch 15: 100%|█| 47/47 [00:01<00:00, 34.49batch/s, test_accuracy=0.904, test_loss=1.77, train_accuracy=0.911, train_lo
Epoch 16: 100%|█| 47/47 [00:01<00:00, 34.07batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.938, train_lo
Epoch 17: 100%|█| 47/47 [00:01<00:00, 33.49batch/s, test_accuracy=0.92, test_loss=1.47, train_accuracy=0.926, train_los
Epoch 18: 100%|█| 47/47 [00:01<00:00, 34.99batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.94, train_los
Epoch 19: 100%|█| 47/47 [00:01<00:00, 34.34batch/s, test_accuracy=0.909, test_loss=1.68, train_accuracy=0.913, train_lo
Epoch 20: 100%|█| 47/47 [00:01<00:00, 34.63batch/s, test_accuracy=0.927, test_loss=1.34, train_accuracy=0.943, train_lo
Epoch 21: 100%|█| 47/47 [00:01<00:00, 35.23batch/s, test_accuracy=0.886, test_loss=2.1, train_accuracy=0.898, train_los
Epoch 22: 100%|█| 47/47 [00:01<00:00, 31.07batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.951, train_lo
Epoch 23: 100%|█| 47/47 [00:01<00:00, 34.50batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.952, train_lo
Epoch 24: 100%|█| 47/47 [00:01<00:00, 34.34batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.958, train_lo
Epoch 25: 100%|█| 47/47 [00:01<00:00, 34.11batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.952, train_lo
Epoch 26: 100%|█| 47/47 [00:01<00:00, 33.95batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.957, train_lo
Epoch 27: 100%|█| 47/47 [00:01<00:00, 34.12batch/s, test_accuracy=0.915, test_loss=1.56, train_accuracy=0.939, train_lo
Epoch 28: 100%|█| 47/47 [00:01<00:00, 33.37batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.947, train_lo
Epoch 29: 100%|█| 47/47 [00:01<00:00, 32.59batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.953, train_lo
Epoch 30: 100%|█| 47/47 [00:01<00:00, 33.33batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.962, train_lo
Epoch 31: 100%|█| 47/47 [00:01<00:00, 35.62batch/s, test_accuracy=0.926, test_loss=1.36, train_accuracy=0.96, train_los
Epoch 32: 100%|█| 47/47 [00:01<00:00, 34.43batch/s, test_accuracy=0.93, test_loss=1.29, train_accuracy=0.957, train_los
Epoch 33: 100%|█| 47/47 [00:01<00:00, 33.92batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.976, train_lo
Epoch 34: 100%|█| 47/47 [00:01<00:00, 34.72batch/s, test_accuracy=0.947, test_loss=0.976, train_accuracy=0.971, train_l
Epoch 35: 100%|█| 47/47 [00:01<00:00, 34.19batch/s, test_accuracy=0.931, test_loss=1.27, train_accuracy=0.969, train_lo
Epoch 36: 100%|█| 47/47 [00:01<00:00, 31.55batch/s, test_accuracy=0.944, test_loss=1.03, train_accuracy=0.977, train_lo
Epoch 37: 100%|█| 47/47 [00:01<00:00, 33.98batch/s, test_accuracy=0.952, test_loss=0.884, train_accuracy=0.983, train_l
Epoch 38: 100%|█| 47/47 [00:01<00:00, 35.29batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.978, train_los
Epoch 39: 100%|█| 47/47 [00:01<00:00, 40.85batch/s, test_accuracy=0.949, test_loss=0.939, train_accuracy=0.985, train_l
Epoch 40: 100%|█| 47/47 [00:01<00:00, 44.46batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.984, train_lo
In [23]:
# print best_train_loss
best_train_loss
Out[23]:
[16.02347237663499,
 11.856510391290234,
 8.54257891546377,
 7.668069220454009,
 5.102840781758711,
 4.111729794080982,
 2.6051183546080248,
 2.4885170619400565,
 1.9699481550746192,
 2.0681387173213293,
 1.9515374246533614,
 1.7490193900195217,
 1.6324180973515536,
 1.5465013553856826,
 1.6446919176323926,
 1.1476021962584233,
 1.3623940511731014,
 1.1107807354159072,
 1.5986650915792473,
 1.0524800890819228,
 1.8686891377576997,
 0.9021257906416481,
 0.8775781500799706,
 0.7732506776928413,
 0.8898519703608094,
 0.8008667733247284,
 1.1322599209073747,
 0.9819056224671,
 0.865304329799132,
 0.6996077560078087,
 0.7425661269907444,
 0.7855244979736801,
 0.4449259851804047,
 0.5277742720760663,
 0.5768695531994212,
 0.4173098895485175,
 0.32218778237201723,
 0.4019676141974691,
 0.27309250124866224,
 0.30377705195075905]
In [24]:
# plotting
fig,(ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

ax1.plot(range(epochs),best_train_loss,label=f'Training Loss')
ax1.plot(range(epochs),best_test_loss,label=f'Test Loss')
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.grid()

ax2.plot(range(epochs),best_train_acc,label=f'Training accuracy')
ax2.plot(range(epochs),best_test_acc,label=f'Test accuracy')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.grid()

ax1.legend()
ax2.legend()
plt.suptitle(f"Outputs from MLP, with best learning rate = {best_lr:.4f}")
plt.show()
In [25]:
### adding the print-outs of final losses
print("The final losses of training and test data are: ",best_train_loss[-1], best_test_loss[-1], "repectively.")
print("The final accuracies of training and test data are: ",best_train_acc[-1], best_test_acc[-1], "repectively.")
The final losses of training and test data are:  0.30377705195075905 1.1230545556967457 repectively.
The final accuracies of training and test data are:  0.9835 0.939 repectively.

1.1.3 train the MLP with reduced width of layers: 50 neurons¶

In [26]:
# retrain using optimal learning rate
best_lr = learning_rates[np.argmin(train_losses)]

width = 50
epochs=40
batch_size = 128

best_w_2,best_b_2, best_train_loss_2, best_train_acc_2, best_test_loss_2,best_test_acc_2 = train(x_train, y_train, x_test, y_test, best_lr, epochs, width=width, batch_size= batch_size)
Epoch 1: 100%|█| 47/47 [00:00<00:00, 168.91batch/s, test_accuracy=0.459, test_loss=9.96, train_accuracy=0.452, train_lo
Epoch 2: 100%|█| 47/47 [00:00<00:00, 141.95batch/s, test_accuracy=0.673, test_loss=6.02, train_accuracy=0.677, train_lo
Epoch 3: 100%|█| 47/47 [00:00<00:00, 143.24batch/s, test_accuracy=0.649, test_loss=6.46, train_accuracy=0.648, train_lo
Epoch 4: 100%|█| 47/47 [00:00<00:00, 161.75batch/s, test_accuracy=0.852, test_loss=2.72, train_accuracy=0.864, train_lo
Epoch 5: 100%|█| 47/47 [00:00<00:00, 166.55batch/s, test_accuracy=0.873, test_loss=2.34, train_accuracy=0.881, train_lo
Epoch 6: 100%|█| 47/47 [00:00<00:00, 151.04batch/s, test_accuracy=0.885, test_loss=2.12, train_accuracy=0.891, train_lo
Epoch 7: 100%|█| 47/47 [00:00<00:00, 161.94batch/s, test_accuracy=0.903, test_loss=1.79, train_accuracy=0.913, train_lo
Epoch 8: 100%|█| 47/47 [00:00<00:00, 162.50batch/s, test_accuracy=0.914, test_loss=1.58, train_accuracy=0.92, train_los
Epoch 9: 100%|█| 47/47 [00:00<00:00, 165.59batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.915, train_lo
Epoch 10: 100%|█| 47/47 [00:00<00:00, 116.65batch/s, test_accuracy=0.913, test_loss=1.6, train_accuracy=0.917, train_lo
Epoch 11: 100%|█| 47/47 [00:00<00:00, 159.21batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.929, train_l
Epoch 12: 100%|█| 47/47 [00:00<00:00, 170.74batch/s, test_accuracy=0.911, test_loss=1.64, train_accuracy=0.911, train_l
Epoch 13: 100%|█| 47/47 [00:00<00:00, 137.39batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.935, train_l
Epoch 14: 100%|█| 47/47 [00:00<00:00, 167.71batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.949, train_l
Epoch 15: 100%|█| 47/47 [00:00<00:00, 168.31batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.946, train_lo
Epoch 16: 100%|█| 47/47 [00:00<00:00, 170.61batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.956, train_l
Epoch 17: 100%|█| 47/47 [00:00<00:00, 174.54batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.954, train_lo
Epoch 18: 100%|█| 47/47 [00:00<00:00, 165.94batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.965, train_l
Epoch 19: 100%|█| 47/47 [00:00<00:00, 167.11batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.965, train_lo
Epoch 20: 100%|█| 47/47 [00:00<00:00, 160.84batch/s, test_accuracy=0.932, test_loss=1.25, train_accuracy=0.967, train_l
Epoch 21: 100%|█| 47/47 [00:00<00:00, 162.17batch/s, test_accuracy=0.901, test_loss=1.82, train_accuracy=0.926, train_l
Epoch 22: 100%|█| 47/47 [00:00<00:00, 155.02batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.966, train_l
Epoch 23: 100%|█| 47/47 [00:00<00:00, 163.63batch/s, test_accuracy=0.88, test_loss=2.21, train_accuracy=0.901, train_lo
Epoch 24: 100%|█| 47/47 [00:00<00:00, 182.66batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.973, train_l
Epoch 25: 100%|█| 47/47 [00:00<00:00, 160.29batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.979, train_l
Epoch 26: 100%|█| 47/47 [00:00<00:00, 156.56batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.978, train_l
Epoch 27: 100%|█| 47/47 [00:00<00:00, 171.37batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.977, train_l
Epoch 28: 100%|█| 47/47 [00:00<00:00, 173.26batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.977, train_l
Epoch 29: 100%|█| 47/47 [00:00<00:00, 166.36batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.984, train_l
Epoch 30: 100%|█| 47/47 [00:00<00:00, 160.29batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.986, train_l
Epoch 31: 100%|█| 47/47 [00:00<00:00, 159.13batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.985, train_l
Epoch 32: 100%|█| 47/47 [00:00<00:00, 119.00batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.986, train_l
Epoch 33: 100%|█| 47/47 [00:00<00:00, 154.51batch/s, test_accuracy=0.927, test_loss=1.34, train_accuracy=0.981, train_l
Epoch 34: 100%|█| 47/47 [00:00<00:00, 167.71batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.991, train_l
Epoch 35: 100%|█| 47/47 [00:00<00:00, 157.40batch/s, test_accuracy=0.945, test_loss=1.01, train_accuracy=0.994, train_l
Epoch 36: 100%|█| 47/47 [00:00<00:00, 164.78batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.987, train_l
Epoch 37: 100%|█| 47/47 [00:00<00:00, 166.25batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.994, train_l
Epoch 38: 100%|█| 47/47 [00:00<00:00, 166.52batch/s, test_accuracy=0.935, test_loss=1.2, train_accuracy=0.989, train_lo
Epoch 39: 100%|█| 47/47 [00:00<00:00, 170.75batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.993, train_l
Epoch 40: 100%|█| 47/47 [00:00<00:00, 152.51batch/s, test_accuracy=0.944, test_loss=1.03, train_accuracy=0.997, train_l
In [27]:
# print best_train_loss_2
best_train_loss_2
Out[27]:
[10.098285636060082,
 5.946665926066375,
 6.47750865321265,
 2.513064702501734,
 2.1816715549190877,
 2.000632705776716,
 1.5925281814388281,
 1.4697899786304403,
 1.5649120858069405,
 1.5188852597537956,
 1.3040934048391173,
 1.6446919176323926,
 1.1997659324519878,
 0.9450841616245838,
 0.9880425326075194,
 0.816209048675777,
 0.8560989645885029,
 0.6535809299546634,
 0.6474440198142442,
 0.6014171937610987,
 1.3531886859624722,
 0.6259648343227762,
 1.8287992218449738,
 0.5032266315143887,
 0.3927622489868399,
 0.4111729794080981,
 0.432652164899566,
 0.4203783446187272,
 0.2915032316699203,
 0.26388713603803315,
 0.2792294113890816,
 0.2516133157571944,
 0.34673542293369475,
 0.15955966365090377,
 0.11966974773817782,
 0.23933949547635563,
 0.11046438252754877,
 0.20251803463383938,
 0.12887511294880688,
 0.05830064633398407]
In [28]:
fig,(ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

ax1.plot(range(40),best_train_acc,label=f'Training accuracy for width = 200')
ax1.plot(range(40),best_test_acc,label=f'Test accuracy for width = 200')
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Accuracy")
ax1.grid()

ax2.plot(range(40),best_train_acc_2,label=f'Training accuracy for width = 50')
ax2.plot(range(40),best_test_acc_2,label=f'Test accuracy for width = 50')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.grid()

ax1.legend()
ax2.legend()
plt.suptitle(f"Effect of Increasing Epoches on Accuracies, with best learning rate = {best_lr:.4f}")
plt.show()

print("For width = 200:")
print("Final Train Accuracy: ", best_train_acc[-1])
print("Final Test Accuracy: ", best_test_acc[-1])
print("\nFor width = 50:")
print("Final Train Accuracy: ", best_train_acc_2[-1])
print("Final Test Accuracy: ", best_test_acc_2[-1])
For width = 200:
Final Train Accuracy:  0.9835
Final Test Accuracy:  0.939

For width = 50:
Final Train Accuracy:  0.9968333333333333
Final Test Accuracy:  0.944
In [29]:
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

ax.plot(range(40), best_train_loss_2, label=f'Training loss for width = 50')
ax.plot(range(40), best_test_loss_2, label=f'Test loss for width = 50')
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.grid()
ax.legend()
plt.show()

print("The final losses of training and test data are: ",best_train_loss_2[-1], best_test_loss_2[-1], "repectively.")
The final losses of training and test data are:  0.05830064633398407 1.031000903590455 repectively.

Explanation:¶

The width of a hidden layer refers to the number of neurons or units in that layer. Similar to the model complexity for linear regression, the more neurons per layer, the more complex patterns the MLP can learn from the data.

From the print-outs, the final loss for the training and test data set are both higher for reduced-width model. This could mean that the model in 1.2.2 performs better, with better fitted test data. In addition, it can be observed that the accuracies of training data and test data are closer to each other. This indicates that model in 1.2.2 fits the unseen data better

However, for both models, the accuracies for the training and test data are increasing, but the accuracies for training and particularly the test data remains stable at width=50. This could be the model with reduced part has had converged.

Given more time, a tolerance can be set to return a signal once the test accuracy and test loss falls within certain criterion for both models.

1.1.4 Retrain MLP with dropout rate (regularization)¶

In [30]:
# retrain MLP using optimal learning rate
# retrain using optimal learning rate

lr = learning_rates[np.argmin(train_losses)]
epochs = 50  # scaled epochs
batch_size = 128
dropout_prob = 0.2

w_dropout,b_dropout,train_loss_dropout, train_acc_dropout,test_loss_dropout,test_acc_dropout = train(x_train, y_train, x_test, y_test, lr, epochs,batch_size,dropout_prob=dropout_prob)
Epoch 1: 100%|█| 47/47 [00:01<00:00, 45.84batch/s, test_accuracy=0.134, test_loss=15.9, train_accuracy=0.131, train_los
Epoch 2: 100%|█| 47/47 [00:01<00:00, 46.44batch/s, test_accuracy=0.403, test_loss=11, train_accuracy=0.404, train_loss=
Epoch 3: 100%|█| 47/47 [00:01<00:00, 45.38batch/s, test_accuracy=0.69, test_loss=5.71, train_accuracy=0.692, train_loss
Epoch 4: 100%|█| 47/47 [00:01<00:00, 45.08batch/s, test_accuracy=0.721, test_loss=5.14, train_accuracy=0.735, train_los
Epoch 5: 100%|█| 47/47 [00:01<00:00, 46.80batch/s, test_accuracy=0.845, test_loss=2.85, train_accuracy=0.841, train_los
Epoch 6: 100%|█| 47/47 [00:01<00:00, 46.83batch/s, test_accuracy=0.848, test_loss=2.8, train_accuracy=0.85, train_loss=
Epoch 7: 100%|█| 47/47 [00:01<00:00, 46.55batch/s, test_accuracy=0.882, test_loss=2.17, train_accuracy=0.876, train_los
Epoch 8: 100%|█| 47/47 [00:01<00:00, 46.68batch/s, test_accuracy=0.888, test_loss=2.06, train_accuracy=0.888, train_los
Epoch 9: 100%|█| 47/47 [00:00<00:00, 48.02batch/s, test_accuracy=0.876, test_loss=2.28, train_accuracy=0.887, train_los
Epoch 10: 100%|█| 47/47 [00:01<00:00, 46.00batch/s, test_accuracy=0.885, test_loss=2.12, train_accuracy=0.894, train_lo
Epoch 11: 100%|█| 47/47 [00:00<00:00, 47.89batch/s, test_accuracy=0.906, test_loss=1.73, train_accuracy=0.896, train_lo
Epoch 12: 100%|█| 47/47 [00:00<00:00, 47.49batch/s, test_accuracy=0.902, test_loss=1.8, train_accuracy=0.902, train_los
Epoch 13: 100%|█| 47/47 [00:00<00:00, 47.07batch/s, test_accuracy=0.908, test_loss=1.69, train_accuracy=0.911, train_lo
Epoch 14: 100%|█| 47/47 [00:01<00:00, 44.17batch/s, test_accuracy=0.921, test_loss=1.45, train_accuracy=0.917, train_lo
Epoch 15: 100%|█| 47/47 [00:01<00:00, 46.10batch/s, test_accuracy=0.916, test_loss=1.55, train_accuracy=0.914, train_lo
Epoch 16: 100%|█| 47/47 [00:01<00:00, 46.74batch/s, test_accuracy=0.919, test_loss=1.49, train_accuracy=0.926, train_lo
Epoch 17: 100%|█| 47/47 [00:01<00:00, 46.50batch/s, test_accuracy=0.915, test_loss=1.56, train_accuracy=0.919, train_lo
Epoch 18: 100%|█| 47/47 [00:00<00:00, 47.82batch/s, test_accuracy=0.918, test_loss=1.51, train_accuracy=0.924, train_lo
Epoch 19: 100%|█| 47/47 [00:01<00:00, 45.38batch/s, test_accuracy=0.924, test_loss=1.4, train_accuracy=0.927, train_los
Epoch 20: 100%|█| 47/47 [00:01<00:00, 38.98batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.932, train_lo
Epoch 21: 100%|█| 47/47 [00:01<00:00, 45.84batch/s, test_accuracy=0.925, test_loss=1.38, train_accuracy=0.926, train_lo
Epoch 22: 100%|█| 47/47 [00:01<00:00, 45.01batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.935, train_lo
Epoch 23: 100%|█| 47/47 [00:01<00:00, 46.06batch/s, test_accuracy=0.933, test_loss=1.23, train_accuracy=0.942, train_lo
Epoch 24: 100%|█| 47/47 [00:01<00:00, 45.89batch/s, test_accuracy=0.938, test_loss=1.14, train_accuracy=0.94, train_los
Epoch 25: 100%|█| 47/47 [00:01<00:00, 45.56batch/s, test_accuracy=0.929, test_loss=1.31, train_accuracy=0.939, train_lo
Epoch 26: 100%|█| 47/47 [00:01<00:00, 44.18batch/s, test_accuracy=0.934, test_loss=1.22, train_accuracy=0.94, train_los
Epoch 27: 100%|█| 47/47 [00:01<00:00, 45.98batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.948, train_lo
Epoch 28: 100%|█| 47/47 [00:01<00:00, 45.18batch/s, test_accuracy=0.936, test_loss=1.18, train_accuracy=0.947, train_lo
Epoch 29: 100%|█| 47/47 [00:01<00:00, 43.39batch/s, test_accuracy=0.94, test_loss=1.1, train_accuracy=0.949, train_loss
Epoch 30: 100%|█| 47/47 [00:01<00:00, 44.37batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.951, train_lo
Epoch 31: 100%|█| 47/47 [00:01<00:00, 43.85batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.953, train_lo
Epoch 32: 100%|█| 47/47 [00:01<00:00, 45.51batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.953, train_lo
Epoch 33: 100%|█| 47/47 [00:01<00:00, 44.97batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.955, train_lo
Epoch 34: 100%|█| 47/47 [00:01<00:00, 44.16batch/s, test_accuracy=0.937, test_loss=1.16, train_accuracy=0.955, train_lo
Epoch 35: 100%|█| 47/47 [00:01<00:00, 43.88batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.952, train_lo
Epoch 36: 100%|█| 47/47 [00:01<00:00, 46.57batch/s, test_accuracy=0.94, test_loss=1.1, train_accuracy=0.958, train_loss
Epoch 37: 100%|█| 47/47 [00:01<00:00, 45.73batch/s, test_accuracy=0.94, test_loss=1.1, train_accuracy=0.961, train_loss
Epoch 38: 100%|█| 47/47 [00:01<00:00, 44.71batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.955, train_lo
Epoch 39: 100%|█| 47/47 [00:01<00:00, 45.30batch/s, test_accuracy=0.947, test_loss=0.976, train_accuracy=0.963, train_l
Epoch 40: 100%|█| 47/47 [00:01<00:00, 41.96batch/s, test_accuracy=0.947, test_loss=0.976, train_accuracy=0.963, train_l
Epoch 41: 100%|█| 47/47 [00:01<00:00, 44.53batch/s, test_accuracy=0.941, test_loss=1.09, train_accuracy=0.961, train_lo
Epoch 42: 100%|█| 47/47 [00:01<00:00, 34.88batch/s, test_accuracy=0.943, test_loss=1.05, train_accuracy=0.964, train_lo
Epoch 43: 100%|█| 47/47 [00:01<00:00, 32.89batch/s, test_accuracy=0.944, test_loss=1.03, train_accuracy=0.965, train_lo
Epoch 44: 100%|█| 47/47 [00:01<00:00, 27.37batch/s, test_accuracy=0.939, test_loss=1.12, train_accuracy=0.965, train_lo
Epoch 45: 100%|█| 47/47 [00:02<00:00, 23.36batch/s, test_accuracy=0.945, test_loss=1.01, train_accuracy=0.967, train_lo
Epoch 46: 100%|█| 47/47 [00:01<00:00, 25.63batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.97, train_los
Epoch 47: 100%|█| 47/47 [00:01<00:00, 29.98batch/s, test_accuracy=0.946, test_loss=0.994, train_accuracy=0.97, train_lo
Epoch 48: 100%|█| 47/47 [00:01<00:00, 31.86batch/s, test_accuracy=0.942, test_loss=1.07, train_accuracy=0.968, train_lo
Epoch 49: 100%|█| 47/47 [00:01<00:00, 34.30batch/s, test_accuracy=0.945, test_loss=1.01, train_accuracy=0.971, train_lo
Epoch 50: 100%|█| 47/47 [00:01<00:00, 33.27batch/s, test_accuracy=0.948, test_loss=0.957, train_accuracy=0.971, train_l
In [31]:
# print train_loss with drop-out rates
print(train_loss_dropout)
[16.001993191143523, 10.98200069628047, 5.673573424817713, 4.888048926844032, 2.9273061369800417, 2.767746473329138, 2.285999027306217, 2.0681387173213293, 2.0865494477425877, 1.9576743347937808, 1.9085790536704257, 1.8073200363535058, 1.6354865524217632, 1.5311590800346342, 1.5802543611579891, 1.3623940511731012, 1.4851322539814888, 1.3930786018751982, 1.3439833207518432, 1.2488612135753427, 1.356257141032682, 1.1874921121711492, 1.0616854542925518, 1.1046438252754875, 1.1261230107669553, 1.101575370205278, 0.9573579819054225, 0.9849740775373096, 0.9420157065543742, 0.9113311558522773, 0.865304329799132, 0.8622358747289223, 0.8376882341672447, 0.834619779097035, 0.892920425431019, 0.7701822226226317, 0.718018486429067, 0.8254144138864061, 0.68733393572697, 0.6904023907971797, 0.7241553965694862, 0.6689232053057119, 0.6413071096738248, 0.6535809299546635, 0.6044856488313084, 0.5492534575675341, 0.5553903677079536, 0.5922118285504697, 0.5431165474271147, 0.5277742720760663]
In [32]:
fig,(ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

ax1.plot(range(epochs),train_loss_dropout,label=f'Training Loss')
ax1.plot(range(epochs),test_loss_dropout,label=f'Test Loss')
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.grid()

ax2.plot(range(epochs),train_acc_dropout,label=f'Training accuracy')
ax2.plot(range(epochs),test_acc_dropout,label=f'Test accuracy')
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy")
ax2.grid()

ax1.legend()
ax2.legend()
plt.suptitle(f"Outputs from MLP, with dropout rate = {dropout_prob} and learning rate = {lr: 4f}")
plt.show()

print("Final Dropout Train Loss: ", train_loss_dropout[-1])
print("Final Dropout Test Loss: ", test_loss_dropout[-1])
print("Final Dropout Train Accuracy: ", train_acc_dropout[-1])
print("Final Dropout Test Accuracy: ", test_acc_dropout[-1])
Final Dropout Train Loss:  0.5277742720760663
Final Dropout Test Loss:  0.9573579819054225
Final without Dropout Train Accuracy:  0.9713333333333334
Final without Dropout Test Accuracy:  0.948
In [33]:
# the first batch
x_batch_test = x_test[:128]

# no dropout
outputs_nd = forward_prop(x_batch_test, best_w, best_b)

# with dropout = 0.2
outputs_d = forward_prop(x_batch_test, best_w, best_b, dropout_prob = 0.2)

# plotting activations of the first hidden layer
activations_nd = outputs_nd[1]['A1'].flatten()
activations_d = outputs_d[1]['A1'].flatten()

plt.hist(outputs_nd[1]['A1'].flatten(), bins=100, density=False, alpha=0.5, label="Without Dropout")
plt.hist(outputs_d[1]['A1'].flatten(), bins=100, density=False, alpha=0.5, label="With Dropout Rate = 0.2")
plt.legend()
plt.show()

print("Sum of activations in MLP without dropout: ", np.sum(outputs_nd[1]['A1']))
print("Sum of activations in MLP with dropout rate = 0.2: ", np.sum(outputs_d[1]['A1']))

print("\nPercentage of zero in activations with dropout:", (len(activations_d[activations_d<1e-10])/len(activations_d))*100,'%')
Sum of activations in MLP without dropout:  20555.677530687786
Sum of activations in MLP with dropout rate = 0.2:  20618.161134698887

Percentage of zero in activations with dropout: 19.78125 %

Expanation:¶

  • Effect of dropout:

Dropout is a technique of regularization to prevent overfitting. It randomly replaces elemts of each col of $W$ with zeros with dropout_prob. By doing this, the capacity of the model is reduced and it prevents the neurons from co-adapting each other too much. Neurons in the network have less dependence on other specific neurons being present, and so each neuron learns features that are more robust, and generalises better.

  • In terms of loss and accuracy plots:

It can be observed from the plots without doubt that the model performs better with dropout:

$\cdot$ less fluctuating losses and accuracies,

$\cdot$ much narrower gaps for training and test data from smaller epochs onwards and,

$\cdot$ higher test accuracies(from $\boldsymbol {0.939}$ to $\boldsymbol {0.948}$).

They indicate that the model reduces disruption of noise and could have converged, and the model generalizes more robustly to unseen data.

  • Comparison between the histograms:

References: http://proceedings.mlr.press/v48/gal16.pdf http://proceedings.mlr.press/v31/damianou13a.pdf

The histogram with dropout is more Gaussian distributed than that without dropout. This can be interpreted in terms of Deep Gaussian process:

Deep Gaussian process is a probablistic substitute for MLP. And in the scenario of this part, the MLP with 1 hidden layer is equivalent to a standard Gaussian process (from refs).

The approach to Gaussian process modelling is to place a prior directly over a class of functions and integrate them out. In MLP, dropping-out can be viewed as a Baysian model averaging, which processes inputs into subsets of neurons and are corresponding to non-zero assignments of Bernoulli mask. By doing so, dropout can reduce the co-adapting of the network on individual neurons, making it learn more distributed representations. This latter is analogous to the former, in the way that a Gaussian process places a prior over functions, encouraging them to be smooth and avoiding overfitting to noisy data.

Therefore, using dropout, the first hidden layer activations is more like a Gaussian process, with smooth and well-behaved functions that generalize well to new data. This corresponds to the behaviour of the orange histogram (without considerations of dropped zeros). And a better generalization corresponds to the trend shown in the loss and accuracy plots.


1.2 Dimensionality reduction (20 marks)¶

For various applications of NMF and PCA, denoising is applied in this section.

For given image matrices, PCA is implemented to see the trend of variance explained by principle components (eigenvectors of the matrix) and for both methods, the first 10 basis components are visualized and differences are explained.

Due to the different natures of the decomposition approaches, effects of denoising for 100 pc are discussed for both methods by reconstructed images. One important thing to notice is that: in the process of training, noisy images have been normalized by first being divided by 255 and then by standard procedure of standardization. Therefore, to compare reconstructed images and original noisy and unnoisy images, reconstructed images should be reverted back by multiplying 'sigma' and adding 'mu' stored ealier. After this step, all images are at the same scale.

Finally, MSE is used as a measure of performance of PCA to see how it behaves wrt the test noisy data and the unnoisy data. Similarly, the reconsrtructed images should be reverted to its original scale by the same steps mentioned above.

In [34]:
# read txt file
MNIST_train_noisy = np.loadtxt('MNIST_train_noisy.txt')
MNIST_test_noisy = np.loadtxt('MNIST_test_noisy.txt')

# inspecting the data
print("The shapes of the training and test data are: ", MNIST_train_noisy.shape, MNIST_test_noisy.shape)
print("The type of the MNIST_train_noisy and MNIST_test_noisy are: ", type(MNIST_test_noisy),  type(MNIST_train_noisy))

print("The max and min of the trainning noisy data: ", np.min(MNIST_train_noisy), np.max(MNIST_train_noisy))
The shapes of the training and test data are:  (6000, 784) (1000, 784)
The type of the MNIST_train_noisy and MNIST_test_noisy are:  <class 'numpy.ndarray'> <class 'numpy.ndarray'>
The max and min of the trainning noisy data:  -107.31 343.8
In [35]:
MNIST_train_noisy[0]  # intergers with Gaussian noise
Out[35]:
array([ 1.46928e+01,  4.15443e+01,  6.50907e-01,  1.84006e+00,
       -1.37111e+01,  2.70121e+00,  6.83062e+00,  2.09018e+00,
        1.32685e+01, -1.81447e+01,  7.60745e+00, -4.30438e+01,
       -8.82577e+00,  1.29603e+01, -1.42484e+01, -2.25878e+01,
       -3.95282e+01, -3.76398e+00, -4.38008e+00,  1.06019e+01,
        2.69545e+01,  1.72741e+00, -2.04378e+01,  7.15860e+00,
        1.14836e+01,  8.48694e+00,  2.12904e+01,  3.88620e+00,
       -2.23335e+01, -1.54388e+01, -3.57090e+01, -5.21508e+00,
       -1.84571e+01, -1.35396e+01,  2.37871e+01,  1.47583e+01,
       -1.28852e+01, -6.30926e+00, -8.74880e+00, -1.16587e+01,
        7.42091e-01,  6.38451e+00, -1.37609e+01, -1.92919e+01,
        2.79193e+01,  6.01952e+00, -7.07875e+00,  1.47998e+00,
        8.53116e+00, -1.10761e+01, -1.63237e+01,  1.98356e+01,
        4.32072e+00,  4.94797e+01,  1.72346e+01, -1.43560e+01,
        9.83602e+00, -9.93811e+00, -2.12885e+01, -6.20323e+00,
       -2.99198e+01,  1.86351e+01,  9.14288e+00,  2.57839e+01,
        2.11070e+01,  3.32087e+01,  1.57824e+00, -1.72484e+01,
       -1.39058e+01, -2.15084e+01,  5.44566e+01, -1.97676e+00,
        3.61940e+00,  1.54496e+01,  1.26325e+01, -2.08133e+01,
       -9.79608e+00,  4.42669e+01,  1.40790e+01,  1.54795e+01,
       -7.34956e+00,  1.70580e+01, -2.84354e+01, -2.64093e+01,
       -1.03846e+01, -1.28380e+01, -2.95061e+01, -1.24579e+00,
        1.59867e+01,  9.92727e+00, -3.32957e-01, -2.42489e+01,
       -1.33814e+01,  7.15406e+00,  7.60305e+00, -1.74021e+01,
       -2.56823e+01, -2.96517e+01,  1.03331e+01,  1.31809e+01,
        6.63456e+00, -9.38165e+00, -3.31801e+01, -3.04959e-01,
        5.39579e+01, -6.06997e+00, -3.28984e+01,  3.67116e+00,
       -1.18050e+00,  4.47812e+00,  3.83445e+01,  6.34854e+00,
       -1.45135e+01, -5.61027e+00,  1.48787e+01,  2.09539e+01,
        4.73510e-01, -1.81998e+01, -2.59578e+01, -1.35279e+01,
       -7.65255e+00,  1.41188e+01,  1.87509e+00,  2.72302e+01,
       -1.04744e+01,  1.91733e+01,  2.74841e+01, -2.85886e+01,
       -1.35327e+01, -1.90314e+01,  1.18274e+01,  2.56375e+01,
        1.93300e+01,  1.84234e+01,  7.69585e+00, -4.21003e+01,
        1.64824e+01, -7.53810e+00,  1.66392e+01, -7.31290e+00,
       -7.70799e+00, -3.15550e+00,  1.93177e+01, -2.63962e+01,
        1.96105e+00, -1.28918e+01,  1.60156e+01, -1.19533e+01,
        1.80504e+01,  5.75270e+00, -9.36589e+00,  8.39003e+00,
       -3.04006e+01, -2.74955e+01,  2.28516e+01,  2.19562e+02,
        1.97703e+02,  6.60478e+00, -9.21046e+00, -8.67712e+00,
        7.40744e+00,  3.67294e+00, -1.69200e+01, -3.03753e+00,
       -1.64380e+01,  1.94649e+01, -4.17285e+01,  1.83169e+01,
        1.21473e+01,  1.69646e+01, -4.00270e+00, -2.88538e+00,
        5.05098e+00, -1.12919e+01, -4.43433e+01, -3.31128e+01,
       -2.34549e+01,  2.34587e+01, -9.08712e-01, -2.25376e+01,
       -3.62749e+01,  2.39294e+01,  1.25245e+02,  2.55935e+02,
        1.89210e+02,  3.83215e+01, -3.21856e+01, -1.65434e+01,
       -2.95512e+01, -2.49321e+01,  2.25440e+01,  2.31692e+01,
        1.36295e+01, -1.04602e+00,  5.22427e+00, -1.37121e+00,
       -7.52505e+00,  1.18560e+01, -7.37543e+00,  2.58198e+01,
        1.03998e+01, -2.96805e+01, -2.07470e+00,  9.65127e+00,
       -2.05366e+01, -8.97081e+00,  2.15243e+01, -9.38292e+00,
        1.35506e+01,  6.92266e+00,  1.62373e+02,  2.70264e+02,
        2.45413e+02,  4.78852e+01,  2.77413e+01,  1.34091e+01,
       -8.51889e+00, -2.64588e+01, -5.93605e+00,  2.32397e+01,
        2.16447e+01,  2.99154e+01,  2.39283e-01,  9.20799e+00,
       -4.50518e+00,  2.87857e+00, -1.30913e+01,  3.20019e+01,
        4.33217e+01, -8.93203e+00,  2.51179e+01,  4.06348e+00,
        5.54683e+00, -8.03559e+00, -1.85558e+00, -2.13523e+01,
        2.53617e+01, -4.36704e+01,  1.88918e+02,  3.00853e+02,
        2.32465e+02,  1.40830e+02,  9.86603e+00, -3.40927e+01,
       -8.55709e+00,  8.53139e+00,  2.64744e+00,  1.82964e+01,
       -1.90327e+01,  3.55915e+01,  1.77847e+01,  1.47957e+01,
        6.43721e+00,  2.03183e+00,  1.01528e+01,  5.51030e+00,
       -9.87462e+00,  9.83367e+00,  1.12899e+01,  1.90928e+01,
        9.20189e+00,  2.73134e+01,  3.64131e+01,  2.11858e+01,
       -7.10387e+00,  3.61107e+01,  2.45559e+02,  2.87364e+02,
        2.39861e+02,  1.77156e+02,  2.87905e+01,  1.07451e+01,
        2.13438e+01, -1.34293e+01,  2.14320e+01,  2.61130e+01,
       -3.71112e+01, -1.03756e+01, -7.07590e-01,  2.06512e+01,
        8.28111e+00, -3.65680e+01, -1.18821e+01,  4.34086e+00,
       -2.44816e+01,  3.04381e+00, -6.55714e+00, -1.43526e+01,
       -6.57490e+00,  1.49190e+00, -3.34429e+01, -1.75223e+01,
        1.19131e+01,  2.08288e+02,  2.79990e+02,  2.42012e+02,
        2.17104e+02,  1.96761e+02,  1.43511e+01,  2.13817e+01,
       -2.17823e+01, -2.00719e+01,  7.89365e-01,  6.57631e+00,
       -4.68865e+00, -7.88515e+00, -3.64777e+01, -4.27916e+01,
       -1.23525e+01,  4.56614e+00,  3.53103e-01,  1.16850e+01,
        1.72275e+01,  1.31295e+01,  2.38088e+01, -2.48391e+01,
        3.24215e+01, -2.36893e+01,  1.61307e+01, -2.42744e+01,
        6.36573e+01,  2.18915e+02,  1.50615e+02,  1.40312e+02,
        2.70070e+02,  1.79195e+02, -4.44998e+00, -2.29097e+01,
       -1.55775e+01,  2.20191e+00, -2.21838e+01,  1.84893e+00,
       -1.84677e+01, -2.17944e+00,  2.30816e+01,  3.12202e+01,
       -2.32265e+01,  7.87255e+00,  1.69636e+01, -2.93580e+01,
       -4.90506e+00, -3.79856e+01,  2.45714e+01, -6.42660e+00,
       -2.06966e+01,  1.01039e+01, -1.17117e+00,  3.78489e+01,
        2.60123e+02,  2.72533e+02,  1.40536e+02,  7.33374e+01,
        2.58281e+02,  1.64409e+02,  3.76308e+01,  4.11274e+00,
        1.10785e+01,  2.19105e+01,  1.34282e+01, -3.59380e+00,
        1.01468e+01, -2.86313e+01,  8.58247e+00,  7.98355e-01,
       -5.65881e+00, -1.65159e+01, -2.48610e+01,  2.16931e+01,
        1.19902e+01, -3.43677e+01,  6.95947e-01,  1.39839e+01,
        2.26090e+01,  1.36541e+01,  2.45212e+01,  2.14205e+02,
        2.44182e+02,  2.15888e+02,  2.22461e+01,  1.13725e+02,
        2.50662e+02,  1.74208e+02, -7.26047e+00,  6.55291e+00,
        3.30733e+00,  3.06851e+00,  2.28504e+01,  5.80103e+00,
       -3.66832e+01,  4.71170e+00,  1.17757e+01,  1.62027e+01,
       -1.55880e+00,  1.76639e+01,  1.93807e+01, -7.50557e+00,
       -1.25410e+01, -3.97182e-01, -1.90590e+00, -1.76868e+01,
       -1.36247e+01, -1.73211e+00,  3.85518e+01,  3.01332e+02,
        2.73972e+02,  7.70106e+01, -1.14361e+01,  1.86369e+02,
        2.46977e+02,  8.31455e+01,  5.65317e+00,  1.27294e+01,
        1.00203e+01, -1.17402e+01,  2.77662e+01, -4.27329e+01,
       -3.56599e+01, -5.54958e+01, -1.32558e+01, -2.81824e+01,
        1.74667e+01, -7.21089e+00, -3.31630e+01,  1.59204e+01,
        3.75129e+00,  2.59298e+00,  1.26323e+01,  9.97398e+00,
        3.53983e+01,  6.54483e+01,  2.17958e+02,  2.49798e+02,
        1.95730e+02,  9.34401e+01,  1.19577e+02,  2.95815e+02,
        2.35680e+02,  3.95958e+01,  2.04519e+01,  2.23056e+01,
        1.04728e+01,  4.13824e+01, -7.01298e+00, -1.56546e+01,
       -1.09702e-01, -1.72753e-01,  3.00668e+01,  9.38009e+00,
       -9.37427e+00, -3.48430e+00, -1.08003e-01,  1.65164e+01,
       -2.92456e+01,  4.36536e+01,  1.77725e+01, -8.17721e+00,
        9.64735e+01,  1.95233e+02,  2.55611e+02,  2.49306e+02,
        2.35548e+02,  2.35231e+02,  2.49989e+02,  2.76956e+02,
        2.74830e+02,  2.48525e+02,  2.06920e+01, -1.45384e+01,
       -2.18008e+01,  9.40981e+00, -1.21378e+01,  6.55521e+00,
        1.50937e+01, -9.53465e+00,  1.30961e+01, -1.44901e+01,
        3.25047e+01,  3.11608e+01, -7.52132e+00,  3.32673e-01,
        3.02300e+01, -1.09056e+01,  4.43964e+01,  4.05151e+01,
        2.66073e+02,  2.51470e+02,  2.78116e+02,  2.46771e+02,
        2.49153e+02,  2.28313e+02,  1.75841e+02,  2.50854e+02,
        2.41130e+02,  1.76525e+02,  8.21311e+01, -1.76462e+01,
        9.79665e-01, -5.03212e+01, -3.38977e+01, -3.71078e+01,
       -3.43196e+01, -1.11945e+01, -4.05062e+01,  1.04211e+01,
        2.35890e+01,  1.46437e+00,  8.63347e+00, -3.71576e-01,
       -8.21729e+00, -1.24452e+01,  2.97835e+01,  9.99399e+01,
        2.26863e+02,  1.92009e+02,  1.56521e+02,  4.45907e+01,
        2.20701e+01, -1.00044e+01,  1.46868e+01,  2.55228e+02,
        2.17912e+02, -1.88956e+01, -2.26498e+01,  6.99782e+00,
        4.11815e+00, -2.00987e+01,  7.12709e+00,  1.81945e+00,
       -2.72156e+00,  9.83292e+00,  2.06573e+00,  7.48773e+00,
        4.96712e+00, -1.87271e+01,  4.35234e+00, -1.12788e+01,
        3.84628e+01, -1.40378e+01,  7.41670e-01, -1.06709e+01,
        5.39856e+00,  8.87864e+00, -5.78087e+00,  3.24254e+01,
        3.37935e+00,  2.86719e+00,  4.89003e+01,  2.56846e+02,
        2.53318e+02, -2.20285e+00,  4.06285e+01,  1.69746e+00,
       -6.53814e+00,  1.45286e+01,  4.93218e+00,  2.58272e+01,
        1.44117e+01, -3.02833e+01,  2.17766e+00,  1.12453e+01,
        2.45699e+01, -5.09358e+00,  3.80909e+01,  2.44727e+01,
       -5.77559e+00, -7.32648e+00, -3.51943e+01, -2.95459e+01,
       -1.25109e+00, -2.91962e+01,  2.20334e+01,  9.05947e+00,
        1.36272e+01, -4.89239e+00,  7.97818e+01,  2.45809e+02,
        2.73546e+02,  6.35912e+01,  1.20074e+01,  2.35910e+00,
       -7.86053e-01,  3.78986e+01,  1.37321e+01,  2.14848e+01,
       -1.61202e+01,  2.09294e+01,  9.21594e+00, -8.41174e+00,
       -2.13135e+01, -9.36809e+00,  7.07539e+00, -1.29072e+01,
        1.97590e+01,  9.11854e+00, -1.73637e+00,  2.76084e+01,
       -2.90821e+01, -4.88655e+00,  8.15024e-01,  4.01875e+00,
       -2.78145e+01,  2.27606e+00,  3.85835e+01,  2.55895e+02,
        2.46100e+02,  5.92354e+01,  1.00394e+01, -2.47892e+01,
       -1.09727e+01, -3.66841e+01,  2.02300e+01, -5.14658e+00,
        6.00640e+00,  2.49677e+01,  5.55689e+00,  4.93382e+01,
       -2.31855e+01, -4.19057e+01, -8.98082e+00, -7.06535e+00,
        1.63045e+01,  1.78965e+00,  2.69667e+01, -3.55541e+00,
        1.28437e+01,  1.73816e+01,  3.39852e+00,  7.77912e+00,
       -1.55787e+01, -4.74119e+01,  2.21011e+01,  2.06226e+02,
        2.67791e+02,  5.78124e+01,  1.16466e+01, -7.92537e+00,
        3.87494e+01, -9.24285e+00, -2.64023e+01,  8.46577e+00,
        2.46033e+01,  5.37295e+00, -2.23286e+01, -3.15278e+01,
        2.94299e+01, -1.86781e+01,  2.72944e+01, -1.43535e+01,
       -1.05593e+01, -3.29578e+01, -1.22854e+01,  8.07984e+00,
        3.90399e+00,  2.47789e+01,  1.54150e+01,  3.86269e+00,
        1.00901e+01, -2.09984e+01,  2.81658e+01,  2.38916e+02,
        2.53203e+02,  9.71723e+01, -1.40918e+01,  6.19697e+00,
        2.52356e+01,  6.32046e+00,  3.43176e+00, -1.40803e+01,
       -4.50321e+00, -2.11911e+01,  2.42583e+01,  3.18574e+01,
        3.24158e+01, -1.76879e+01,  5.04080e+01,  3.46684e+01,
       -9.57479e+00,  5.87925e+00, -1.27446e+01, -2.80142e+01,
        9.41069e+00, -1.63209e+00, -1.22236e+01,  2.76099e+00,
        2.37023e+01,  2.34141e+01, -1.38385e+01,  1.40432e+02,
        2.36219e+02,  3.01315e+01, -5.26530e+00,  4.46779e+00,
       -2.16320e+01, -1.47344e+00,  1.34955e+01, -2.41033e+01,
        1.36706e+01,  5.83552e-01, -1.54566e+00,  6.31859e-01,
        6.64487e+00, -1.72672e+01, -1.65740e+00,  1.20156e+01,
        5.30149e+00,  4.95793e+00,  1.43906e+01,  2.97536e+01,
       -2.87547e+01,  1.10865e+01, -3.50675e+01, -2.82799e-01,
        1.72594e+01,  7.98271e-01, -1.63411e+01,  3.80120e+00,
        7.58438e+00,  1.45247e+01, -3.70579e+00,  5.45871e+00,
        3.97541e+00,  2.03193e+01, -1.78142e+01,  3.45178e+01,
       -2.15872e+01,  1.39039e+01, -2.05152e+01,  3.00043e+01,
        4.30353e-01, -1.14632e+01,  2.76469e+00,  1.93526e+01,
        9.97542e+00, -2.63282e+01, -1.77778e+01,  7.07232e+00,
        3.44905e+01, -6.07877e+00,  3.30198e+01,  5.51341e-01,
       -3.05719e+01, -1.05455e+01, -2.89520e+01,  5.01119e+00,
        1.15149e+01,  4.65715e+01, -3.95418e+01, -2.13420e+01,
       -2.35853e+00,  2.30258e-01, -2.91287e+01, -8.13015e+00,
        7.83440e+00,  5.71136e+00,  7.19097e+00, -6.76401e+00,
       -2.70748e+01,  9.21056e+00, -1.35323e+01, -4.13871e+01,
        2.44157e+01, -2.49330e+01,  1.33811e+01,  8.26138e+00,
       -7.70961e+00, -2.92449e+00,  1.43662e+01,  3.78538e+00,
       -1.25971e+01,  1.95713e+01, -6.30019e+00, -7.03768e+01,
        2.76641e+01,  2.68981e+01,  2.58495e-01,  2.87334e+00,
       -2.50393e+01,  1.99615e+00,  1.61607e+01, -5.71961e+00,
        5.05598e+01,  2.13238e+01, -1.22810e+01, -1.42812e+01])
In [36]:
# plot the first image
plt.figure(figsize=(4,4))
plt.imshow(MNIST_train_noisy[0].reshape(28,28), cmap='gray');
In [37]:
# reshape the data
train_noisy = MNIST_train_noisy.reshape(-1, 28*28)
test_noisy = MNIST_test_noisy.reshape(-1, 28*28)

# reshape the original NMIST data sets
train_original = MNIST_train[:,1:].reshape(-1, 28*28)
test_original = MNIST_test[:, 1:].reshape(-1, 28*28)
In [38]:
# define a function to normalize the data
def normalize(X, mu=None,std=None, return_stats=False, revert=False):
    if revert:
        return X*std+mu
    else:
        X = X/255.
        if mu is not None and std is not None:
            Xbar = ((X-mu)/std)
        else:
            mu = np.mean(X, axis=0)
            std = np.std(X, axis=0)
            std_filled = std.copy()
            std_filled[std==0] = 1.
            Xbar = ((X-mu)/std_filled)
        if return_stats:
            return mu, std_filled
        else:
            return Xbar

standardization

In [39]:
mu_pca,std_pca = normalize(train_noisy, return_stats=True)
train_noisy_pca = normalize(train_noisy)
test_noisy_pca = normalize(test_noisy, mu_pca, std_pca)

mu_original, std_original = normalize(train_original, return_stats=True)
train_original_pca = normalize(train_original)
test_original_pca = normalize(test_original, mu=mu_original, std=std_original, revert=False)
In [40]:
print("The shapes of the reshaped training noisy data set is: ", train_noisy_pca.shape)
print("The shape of the reshaped and training original data set is: ", train_original_pca.shape)
The shapes of the reshaped training noisy data set is:  (6000, 784)
The shape of the reshaped and training original data set is:  (6000, 784)

1.2.1 PCA decomposition and show the variance explained as well as the first 10 pc¶

In [41]:
# perform PCA (from coding books)
def pca_function(X, m):
    """
    Return the X_pca matrix, the pcs, and corresponding eigen values.
    X: data set containing images.
    m: number of cpa.
    """
    # covariance matrix C
    C = 1.0/(len(X)-1) * np.dot(X.T, X)
    
    if m < len(X[0]):
        eigenvalues, eigenvectors = scipy.sparse.linalg.eigsh(C, m, which="LM", return_eigenvectors=True)
    else: 
        eigenvalues, eigenvectors =  scipy.linalg.eigh(C) 

    # sorting and eigenvalues from largest to smallest eigenvalue
    sorted_index = np.argsort(eigenvalues)[::-1]
    eigenvalues = eigenvalues[sorted_index]
    # v[:, i] is the ith e.vec, corresponding the ith e.value
    eigenvectors = eigenvectors[:,sorted_index]

    X_pca = X.dot(eigenvectors)
    return X_pca, eigenvectors, eigenvalues
In [42]:
m = 784  # the row number of the data
X_pca, eigenvectors, eigenvalues = pca_function(train_noisy_pca, m)
var_explained = [evalue/sum(eigenvalues) for evalue in eigenvalues]

# varaicne of pc as m increases
# for multiple pc, the variance explained is the sum of variance over total variance
m_var_explained = np.cumsum(var_explained)
In [43]:
# plot the variance explained against m
plt.figure(figsize=(10,6))
plt.plot(m_var_explained, color='blue', label='Variance Explained')
plt.xlabel('Number of Principle Components')
plt.ylabel('Variance Explained')
plt.title('Variance Explained by Principle Components')

# plot horizontal lines at which the value of variane explained reached 0.7, 0.8 and 0.9
plt.axhline(y=0.7, color='red', linestyle='--', label='70% exlained variance')
plt.axhline(y=0.8, color='green', linestyle='--', label='80% exlained variance')
plt.axhline(y=0.9, color='orange', linestyle='--', label='90% exlained variance')

# print the first value of m when the variance explained is 0.7, 0.8 and 0.9
print("The first value of m when the variance explained is 0.7 is:", np.where(m_var_explained>0.7)[0][0])
print("The first value of m when the variance explained is 0.8 is:", np.where(m_var_explained>0.8)[0][0])
print("The first value of m when the variance explained is 0.9 is:", np.where(m_var_explained>0.9)[0][0])
plt.legend(loc='best', shadow=True, fontsize='x-large')
plt.grid()
plt.show()
The first value of m when the variance explained is 0.7 is: 212
The first value of m when the variance explained is 0.8 is: 297
The first value of m when the variance explained is 0.9 is: 407
In [44]:
# visualize the first 10 principle components
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i, ax in enumerate(ax.reshape(-1)):
    ax.imshow(eigenvectors[:, i].reshape(28, 28))
    ax.set_ylabel(f"PC{i+1}")

plt.tight_layout();

1.2.2 NMF decomposition and comparison between components with PCA's¶

In [45]:
train_noisy_nmf = (MNIST_train_noisy.reshape(-1, 28 * 28)) / 255.
test_noisy_nmf = (MNIST_test_noisy.reshape(-1, 28 * 28)) / 255.

# max_min normalization
def normalize_nmf(X, min=None, max=None, return_stats = False, revert=False):
    if revert:
        return X*(max-min)+min
    if return_stats:
        return X.min(), X.max()
    else:
        X_norm = (X - X.min()) / (X.max() - X.min())
        return X_norm
In [46]:
# normalization
train_noisy_nmf = normalize_nmf(train_noisy_nmf)
min_nmf, max_nmf = normalize_nmf(test_noisy_nmf, return_stats=True)
In [47]:
# define chi2 cost, same as the notebook
def cost(X, W, H):
    """Return the chi2 cost of the NMF decomposition."""
    # compute the difference between X and the dot product of W and H
    diff = X - np.dot(W, H) ## <-- EDIT THIS LINE
    chi2 = ((X*diff) * diff).sum() / (X.shape[0]*X.shape[1])

    return chi2
In [48]:
# Implement NMF
# construct placeholder matrices
np.random.seed(0)
m = 10
# m x k components matrix, usually interpreted as the coefficients
W = np.random.rand(train_noisy_nmf.shape[0], m)
# k x n matrix interpreted as the basis set(e.g. pixels)
H = np.random.rand(m, train_noisy_nmf.shape[1])

chi2 = []
n_iters = 200  # the number of iterations
eps = 1e-5  # check for convergence

# loop to find chi2 error against iterations  (about 12.5 mins on microsoft)
for i in range(n_iters):
    # update first on H
    H = H * ((W.T.dot(train_noisy_nmf)) / (W.T.dot(W.dot(H)))) ## <-- EDIT THIS LINE
    # the update on W
    W = W * ((train_noisy_nmf.dot(H.T)) / (W.dot(H.dot(H.T)))) ## <-- EDIT THIS LINE
    # compute the chi2 and append to list
    chi2.append(cost(train_noisy_nmf, W, H))

# check for convergence
for i in range(1, len(chi2)):
    if abs(chi2[i-1] - chi2[i]) < eps:
        print(f"Converged at iteration {i} and the difference is {chi2[i-1] - chi2[i]}")
        break
Converged at iteration 82 and the difference is 9.69217069603321e-06
In [49]:
print("The loss for NMF at m=100: ", chi2[99])
The loss for NMF at m=100:  0.007368975662134399
In [50]:
# plot the cost as a function of the number of iterations
plt.plot(chi2, label="Cost")
plt.xlabel("Number of Iterations")
plt.ylabel("chi2 Cost")
plt.title("Cost as a function of the number of iterations")
plt.legend(loc='best', shadow=False, fontsize='x-large')
plt.show()

Convergence check item¶

Particularly, epsilon is added to print the first iteration where the current is of the last ieration is less than eps difference than the last iteration. If the number is far from the chosen number of n_iter and that the cost is monotonically decreasing, we have reasons to believe that the number of iterations chosen is sufficient for converge. As the output and the plot shows, the fisrt iteration that meets the criterion is 82 and the cost curve is decreasing. Therefore, n_iter=200 is a suitable choice to make sure that the cost converges.

In [51]:
# visualize the m=10  components of NMF
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i, ax in enumerate(ax.reshape(-1)):
    ax.imshow(H[i].reshape(28, 28))  # H[i] here represents the i-th positive eigenvector
    ax.set_ylabel(f"NMF{i+1}")

plt.tight_layout();
In [52]:
# visualize the first 10 principle components of PCA
fig, ax = plt.subplots(2, 5, figsize=(10, 5))
for i, ax in enumerate(ax.reshape(-1)):
    ax.imshow(eigenvectors[:, i].reshape(28, 28))
    ax.set_ylabel(f"PC{i+1}")

plt.tight_layout();

Explanation:¶

The components of PCA and NMF are so different. This arises from their different natures of the way of decomposition.

  • PCA:

The components of PCA are called 'eigenfaces', as each principle component of PCA is a modified version of the images. When combining together linearly, a lot of cancellations will be involved.

  • NMF:

The components of NMF have a direct visual meaning: it contains local and sparse features of the images.

In all, each imshow shown for PCA components is a global representation of an image, but for NMF, each imshow is a partial and sparse representation of the image.

(references: https://www.nature.com/articles/44565)

1.2.3 Reconstruct images and compare results.¶

In [53]:
# define a function to compute the mse score
def mse_score(X_either, X_reconstructed):
    """Return the mean square error between the reconstructed and corrupted or corrupted images."""
    return np.mean(np.square(X_either - X_reconstructed))
In [54]:
# train on PCA
m = 100
X_pca, train_eigenvectors, train_eigenvalues = pca_function(train_noisy_pca, m)

# the reconstructed images by pca
X_reconstructed_pca = test_noisy_pca @ train_eigenvectors @ train_eigenvectors.T  # formula from notes

# train on NMF
# construct placeholder matrices
np.random.seed(0)
W = np.random.rand(train_noisy_nmf.shape[0], m)
H = np.random.rand(m, train_noisy_nmf.shape[1])

chi2 = []
n_iters = 200  # the number of iterations

# loop over  (about 20 mins on microsoft)
for i in range(n_iters):
    # update first on H
    H = H * ((W.T.dot(train_noisy_nmf)) / (W.T.dot(W.dot(H)))) 
    W = W * ((train_noisy_nmf.dot(H.T)) / (W.dot(H.dot(H.T))))
    chi2.append(cost(train_noisy_nmf, W, H))

# reconstructed images by NMF (formula derived from notes: rows of H are analogue to principle components in PCA)
X_reconstructed_nmf = test_noisy_nmf @ H.T @ H

revert the images back to its original scale

In [55]:
X_reconstructed_pca_revert = normalize(X_reconstructed_pca,mu_pca,std_pca,revert=True)
X_reconstructed_nmf_revert = normalize_nmf(X_reconstructed_nmf,min_nmf,max_nmf,revert=True)
In [56]:
np.random.seed(0)
# randomly choose an image from the noisy test data set
random_image = np.random.randint(0, test_noisy_pca.shape[0])

# plot the noisy, the reconsrtucted and the original image in a row
plt.figure(figsize=(10, 6))

plt.subplot(2, 2, 1)
plt.imshow(test_noisy[random_image].reshape(28, 28))
plt.title("Noisy Image")
# plot the reconstructed image by PCA
plt.subplot(2, 2, 2)
plt.imshow(X_reconstructed_pca_revert[random_image].reshape(28, 28))
plt.title("Reconstructed Image by PCA")
# plot the reconstructed image by NMF
plt.subplot(2, 2, 3)
plt.imshow(X_reconstructed_nmf_revert[random_image].reshape(28, 28))
plt.title("Reconstructed Image by NMF")
# plot the original image
plt.subplot(2, 2, 4)
plt.imshow(test_original[random_image].reshape(28, 28))
plt.title("Original Image")

plt.tight_layout()
plt.show()

Analysis:¶

In terms of denoising, NMF performs better, as the feature distribution is more like the original image. But in terms of visualizing images, PCA perfroms better, as it's much clearer than the reconstructed image by NMF.

Explanation: NMF has the following properties

  • by the nature of basis and coefficient components of NMF, H and W contain a large proportion of vanishing coefficients. So both W and H(features) are sparse.
  • Also, each component is a local representation of an image.
  • Components of H are non-negative.

The reconstructed images are built using a linear combination of different local parts of the image(the non-negativeness of H only gives addition). This means that not all available local features are used in the linear combination. Therefore, due to the properties mentioned above, sparse addition of features is giving much less noise but also is less likely to give as many features.

On the contrary, for PCA:

  • each componet of PCA is an 'eigen image', representing a modified version of the whole image.
  • the components of eigenvalues and eigenvectors could be both positive and negative, meaning that the directions of variance of features can vary vastly.

The reconstructed images are built on a linear combination, not only with different eigen images, but also both additions and subtractions are allowed. The use of the whole image(rather than sparse and local features) in reconstruction and the complexities of additions and subtractions involved mean that the recontructed image will combine more features as well as noise, thus giving an better-visualized but more noisy image.

(ref again from 1.2.2: https://www.nature.com/articles/44565)

1.2.4 MSE between denoised data and test_original and test_noisy data¶

In [57]:
# define a function to compute the mse score
def mse_score(X_either, X_reconstructed):
    """Return the mean square error between the reconstructed and corrupted or corrupted images."""
    return np.mean(np.square(X_either - X_reconstructed))
In [58]:
np.random.seed(10)
m_range = np.arange(5, 601, 5)  # pc values
random_image = np.random.randint(0, test_noisy_pca.shape[0])

with_uncorrupted_mse_score_lis = []
with_corrupted_mse_score_lis = []
pca_examples_holder = []

for m in m_range:
    _, train_eigenvectors, _ = pca_function(train_noisy_pca, m)
    X_reconstructed_pca = test_noisy_pca @ train_eigenvectors @ train_eigenvectors.T
    X_reconstructed_pca_revert = normalize(X_reconstructed_pca, mu_pca, std_pca, revert=True)
    
    with_uncorrupted_mse_score = mse_score(test_original/255, X_reconstructed_pca_revert)
    with_corrupted_mse_score = mse_score(test_noisy/255, X_reconstructed_pca_revert)

    with_uncorrupted_mse_score_lis.append(with_uncorrupted_mse_score)
    with_corrupted_mse_score_lis.append(with_corrupted_mse_score)

    if m in [10, 40, 100, 200, 400, 600]:
        X_reconstructed_pca_revert = normalize(X_reconstructed_pca,mu_pca,std_pca,revert=True)
        pca_examples_holder.append(X_reconstructed_pca_revert[random_image])
In [59]:
# plot the MSE
plt.figure(figsize=(10, 6))
plt.plot(m_range[0:80] , with_uncorrupted_mse_score_lis[0:80], label="MSE of reconstructed with denoised images")
plt.plot(m_range[0:80], with_corrupted_mse_score_lis[0:80], label="MSE of reconstructed with noisy images")
plt.axvline(100, color="red", linestyle='--')
plt.xlabel('number of principle components')
plt.ylabel('MSE')
plt.title('MSE of images against number of principle components')
plt.legend(loc='best', shadow=False, fontsize='x-large')
plt.show()
In [60]:
# plot the example figure
plt.figure(figsize=(10, 10))
plt.subplot(2,4,1)
plt.imshow(test_original[random_image].reshape(28,28))
plt.ylabel('original test image')

plt.subplot(2,4,2)
plt.imshow(test_noisy[random_image].reshape(28,28))
plt.ylabel('noisy test image')

plt.subplot(2,4,3)
plt.imshow(pca_examples_holder[0].reshape(28,28))
plt.ylabel('m = 10')

plt.subplot(2,4,4)
plt.imshow(pca_examples_holder[1].reshape(28,28))
plt.ylabel('m = 40')

plt.subplot(2,4,5)
plt.imshow(pca_examples_holder[2].reshape(28,28))
plt.ylabel('m = 100')

plt.subplot(2,4,6)
plt.imshow(pca_examples_holder[3].reshape(28,28))
plt.ylabel('m = 200')

plt.subplot(2,4,7)
plt.imshow(pca_examples_holder[4].reshape(28,28))
plt.ylabel('m = 400')

plt.subplot(2,4,8)
plt.imshow(pca_examples_holder[5].reshape(28,28))
plt.ylabel('m = 600')

plt.suptitle("Reconstructed images for test noisy data at different m")
plt.tight_layout()
plt.show()

Explanation:¶

  • The MSE decreases fast at about first m=100 components (dotted-red line). This is because, PCA factorizes out principle components in decending order and eigenvectors(which statistically represent variance explained) corresponding to greater eigenvalues bear with more represented features in the reconstructed images. Therefore MSE decreases rapidly at start.

  • After m=100, the MSE with test data still decreases while MSE with original data remains almost stable. This means that the added principle components help little in denoising the noisy images and thus adding more components is pointless in decreasing MSE. But for test noisy data, the added components are adding more information (but it's just noise) about images to the basis of components and give a decreasing MSE.

The trend above(orange) can be corresponded in the reconstructed test noisy images plotted for different $m$s:

  • Images become clear rapidly from m=10 to m=40 and m=100.

  • Images change little, either the background or the number part, from m=100 to m=400, because adding components are just adding more noise.

In all, none of the reconstructed images resemble the original image for the background part, which means that the PCA does not perform well in denoising data. But the reconstructed images are closer to the noisy data, meaning that PCA components fit well to unseen noisy data and is good at visualizing noisy images.

As m further increases, the MSE with reconstructed images and test noisy images are expected to further decrease (which is verified below). This trend can also be seen from the last plot.

In [61]:
# verify the MSE after 400
plt.figure(figsize=(10, 6))
plt.plot(m_range, with_corrupted_mse_score_lis, color='orange', label="MSE of reconstructed with test noisy images")
plt.xlabel('number of principle components')
plt.ylabel('MSE')
plt.title('MSE of images against number of principle components')
plt.legend(loc='best', shadow=False, fontsize='x-large')
plt.show()

1.3 Gaussian Mixture Models (20 marks)¶

In this section, digits' images are reconstructed based on 5 principle components and then are being clustered by GMM models, with 10, 5 and 8 hidden components respectively.

For each of 3 models:

  • the best-fitting cluster indexes are computed, by involving log likelihoods.
  • to measure the uncertainty of clustering, space spanned by the first 2 principle components are plotted as well as for each digit class.

Intuitively, number of hidden components should be at least the number of digits, but for similar types of images, especially those only built on 5 pc, GMM with less hidden components are likely to perform better. The discussion will be included in the section.

Note that in EM algorithm, a convergence check is set up after each EM step to stop the next iteration if the previous mu and sigma are close to avoid overflow.

1.3.1 GMM on probablistic clustering with 10 hidden components¶

In [62]:
class GMModel:
    """Define Gaussian Mixture Model class"""
    """:param dim: number of mixture components"""
    """:param weights: mixture weights"""
    """:param mu: mixture component means for each cluster"""
    """:param sigma: mixture component covariance matrix for each cluster"""
    
    def __init__(self, X, dim):
        """Initialises parameters through random split of the data"""
        
        self.dim = dim  # number of k

        # initial weights/ P(Ci=j)/ prior
        self.phi = np.full(shape=self.dim, fill_value=1/self.dim)  # <- fill the array of shape with values fill_value

        # initial weights/ P(Xi/Ci=j)/ likelihood
        self.weights = np.full(shape=X.shape, fill_value=1/self.dim)
        
        n, m = X.shape 
        # as a generator of self.mu
        random_row = np.random.randint(low=0, high=n, size=self.dim)  # <- could be repeated

        # initial value of mean of k Gaussians and sigmas
        self.mu = [  X[row_index,:] for row_index in random_row ]
        self.sigma = [ np.cov(X.T) for _ in range(self.dim) ] 
In [63]:
def cluster_probabilities(gmm, X):
    """Predicts cluster probability for each data point."""
    
    n, m = X.shape
    # l_ij = p(x_i|theta_j)
    likelihood = np.zeros((n, gmm.dim))

    for i in range(gmm.dim):
        # likelihood of data belonging to i-th cluster 
        distribution =  multivariate_normal(mean=gmm.mu[i], cov=gmm.sigma[i]) # <- from scipy
        likelihood[:,i] = distribution.pdf(X)

        numerator = likelihood * gmm.phi
        denominator = numerator.sum(axis=1)[:, np.newaxis] # axis=1: col sum: across diff k
        weights = numerator / denominator

    return weights
In [64]:
def predict(gmm, X):
    """Performs hard clustering"""
    weights = cluster_probabilities(gmm, X)
    return  np.argmax(weights, axis=1)
In [65]:
# implement EM algorithm
def fitStep(gmm, X):
    """Performs an EM step by updating all parameters"""
    
    # E-Step: update weights and phi holding mu and sigma constant: down in "/total_weight"
    # M-Step: update mu and sigma holding pi and weights constant
    weights = cluster_probabilities(gmm,X)
    gmm.phi = weights.mean(axis=0)  # prior
        
    for i in range(gmm.dim):
        weight = weights[:, [i]]
        total_weight = weight.sum()

        gmm.mu[i] = (X * weight).sum(axis=0) / total_weight
        # bias=True: normalize by num(observations)
        gmm.sigma[i] =  np.cov(X.T, aweights=(weight/total_weight).flatten(), bias=True)
In [66]:
# train the model with EM. But with a convergence check
def train_gmm(X, n_components, n_iters=1000, eps=1e-8):
    gmm = GMModel(X,n_components)
    prev_mu = gmm.mu.copy()
    prev_std = gmm.sigma.copy()
    for i in range(n_iters):
        fitStep(gmm,X)
        if np.allclose(prev_mu, gmm.mu) and np.allclose(prev_std, gmm.sigma):
            print(f"Converged at iteration {i}!")
            break
        prev_mu = gmm.mu.copy()
        prev_std = gmm.sigma.copy()
    return gmm
In [67]:
m = 5

np.random.seed(2)
X_pca_131, eigenvectors_131, eigenvalues_131 = pca_function(train_original_pca[0:1000], m)  
gmm = GMModel(X_pca_131, 10)
gmm = train_gmm(X_pca_131, n_components=10, n_iters=1000, eps=1e-8)

# hard clustering
cluster_labels = predict(gmm, X_pca_131)
# visualize the space spanned
plt.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=cluster_labels, cmap="tab10")
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.colorbar(label='Cluster Labels')
plt.title('GMM Clustering on MNIST_train on PC1 and PC2, using cluster labels')
plt.show()
Converged at iteration 378!

1.3.2 Find the best-fitting cluster index and discuss the result¶

In [68]:
class_labels = MNIST_train[:1000, 0]

# visualize the space spanned
plt.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=class_labels, cmap="tab10")
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.colorbar(label='Class Labels')
plt.title('GMM Clustering on MNIST_train on PC1 and PC2, using class labels.')
plt.show()
In [69]:
def log_probs(gmm, X):
    """
    Return a matrix of shape (n_sample, n_components), 
        each elemt is the log probability of X_i belonging to the kth mixture component.
    """

    n,m = X.shape 
    log_prob = np.zeros((n, gmm.dim))
    
    for i in range(gmm.dim):
        distribution = multivariate_normal(mean=gmm.mu[i],cov=gmm.sigma[i])
        # log probability for each x at component i
        log_prob[:,i]=distribution.logpdf(X)

    return log_prob
In [70]:
def label_cluster_mapping(gmm, X, class_labels):
    """
    Return a map of class labels to best_fitting cluster index.
    Args:
        gmm: Gaussian mixture model.
        X: data set
        class_labels: given class labels.
    """

    all_log_likelihoods = log_probs(gmm, X)

    class_log_likelihood = []
    for i in np.unique(class_labels):
        label_indicator = (class_labels == i)
        class_log_likelihood.append(np.sum(all_log_likelihoods[label_indicator],axis = 0))
    
    cluster_labels = np.argmax(class_log_likelihood, axis=1)

    return {class_label: cluster_label for class_label, cluster_label in zip(np.unique(class_labels), cluster_labels)}
In [71]:
# print the map
class_cluster_map = label_cluster_mapping(gmm, X_pca_131, class_labels)
print("The label cluster mapping is: ", class_cluster_map)
The label cluster mapping is:  {0: 5, 1: 0, 2: 4, 3: 1, 4: 3, 5: 1, 6: 6, 7: 3, 8: 1, 9: 3}

Analysis:¶

The keys of the mapping dictionary are the class label and the values of the mapping are best-fitting index of normal distributions:

$\cdot$ One interesting thing to notice is that digit 3, 5 and 8 are both mapped to label 1. In reality of hand-writting, these 2 digits do look alike for roughly their shapes.

$\cdot$ In addition, digit 4, 7 and 9 are clustered to label 3.

  • Reason: This GMM model assumes each digit is driven by a corresponding latent distribution. But we only use 5 principle components to represent the images, and from section 1.2, we know that images reconstructed by m=5 are clear in the background but blurred on the lighter part -- where the digits manifest. This tells the features captured are insufficient to represent distinct features of images of digits. And digits with similar rough shapes are more likely to be clustered as the identical digit.

Analysis wrt the plot above is included in the explanation in 1.3.3

1.3.3 Uncertainty of clusteirng visualized on classes¶

In [72]:
# log probs matrix: containing all the log-probs of data point, X_i to the kth mixture components
log_cluster_prob = log_probs(gmm, X_pca_131)
print(log_cluster_prob)
[[ -16.97570347  -18.7475566   -14.76573705 ...  -19.31242479
  -282.41979186  -11.23842018]
 [-114.66542145  -15.01411727  -19.51109961 ...  -32.36062497
  -200.37780552  -44.10846414]
 [-101.16633094  -15.31546101  -11.40404744 ...  -33.75332819
  -170.75535242  -57.92117636]
 ...
 [-148.20503636  -13.40809482  -17.53956278 ...  -26.11280143
  -132.32450628  -47.96001604]
 [-698.52436219  -12.81893972  -52.78271987 ...  -30.4427656
  -110.48773354  -63.4990667 ]
 [-226.9374572   -19.090202    -15.9754111  ...  -12.50942501
  -110.07120331  -34.46886479]]
In [73]:
# replot from 1.3.2
# visualize the space spanned
plt.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=cluster_labels, cmap="tab10")
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.colorbar(label='Cluster Labels')
plt.title('GMM Clustering on MNIST_train on PC1 and PC2, using cluster labels')
plt.show()
In [74]:
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(30, 10), sharex=True, sharey=True)

for i, ax in enumerate(axes.flatten()):
    label_indicator = (class_labels == i)
    cluster_probs_total = cluster_probabilities(gmm, X_pca_131[label_indicator, :]) 
    cluster_probs = cluster_probs_total[:, class_cluster_map[i]]

    cmap = plt.cm.get_cmap('viridis')
    norm = plt.Normalize(vmin=cluster_probs.min(), vmax=cluster_probs.max())
    sm = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
    colors = sm.to_rgba(cluster_probs)

    ax.scatter(X_pca_131[label_indicator, 0], X_pca_131[label_indicator, 1], c=colors)
    ax.set_title('Digit {}'.format(i))
    
fig.colorbar(sm, ax=axes.ravel().tolist(), shrink=0.75, label='Cluster Probabilities')
plt.suptitle("Visualization of Each Class, Colored By the Cluster Probability of the Best-fitting Cluster")
plt.show()

Analysis:¶

The plot above gives the uncertainty of the class points assigned to the best-fitting cluster. The brighter the color, the more likely that the digits are assigned to the best-fitting cluster, based on 5 principle components.

To view the uncertainty:

  • First take a look on how the cluster probabilities vary within each class. If the distribution of the cluster probabilities are evenly distributed, it means more uncertainty in the clustering as such distributions makes it hard to decide the dominant digit and therefore lead to unstableness. Moreover, we want to have as many yellow dots as possible to determine the dominant digit.

By this, GMM model with 10 hidden components has a great uncertainty: Apart from digit 0, 1, 2 and 3, other digit classes either has a distribution with a wide range of values of cluster probabilities or there are many dark-color points which makes it hard to determine the dominant label.

  • Second, we can take a look on how well-separated the different clusters are in the scatter plot above. If different clusters are well-separated, then it suggests the clustering is more certain and the cluster probability is low. On the other hand, if the a region of plot is densely stacked with clusters, it has smaller uncertainty ie. higher cluster probability.

From this, the clusters are highly stacked on the lower right part of the plot, giving low cluster probability to the corresponding areas of each class, eg. for the green dots in the scatter plot, it has orange, red and pink stacking on it, giving a dark color of cluster probability at the corresponding lower right part in the plot of digit 8. For clyster 5, the brown dots, it's individually and densely distributed around the area $(-25, 15)\cdot (5, 15)$. Therefore, it gives yellow dots to the corresponding region in the plot of digit 0.

1.3.4 Retrain GMM with 5 nad 8 digits¶

In [75]:
def train_map_plots(hidden_components, n_its = 1000):
    # train
    gmm = train_gmm(X_pca_131, hidden_components)

    # map
    class_labels = MNIST_train[:1000, 0]
    label_cluster_map = label_cluster_mapping(gmm, X_pca_131, class_labels)
    print("Label-Cluster Index Map: ", label_cluster_map)

    # plot
    fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(30, 10))
    # color by clusters
    cluster_labels = predict(gmm, X_pca_131)
    sc0 = ax0.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=cluster_labels, cmap="tab10")
    ax0.set_xlabel('PC1')
    ax0.set_ylabel('PC2')
    ax0.set_title(f'Coloring Cluster Labels, for {hidden_components} components')
    cbar = fig.colorbar(sc0, ax=ax0)
    cbar.ax.set_ylabel('Cluster Labels')
    # color by class labels
    sc1 = ax1.scatter(X_pca_131[:, 0], X_pca_131[:, 1], c=class_labels, cmap="tab10")
    ax1.set_xlabel('PC1')
    ax1.set_ylabel('PC2')
    ax1.set_title(f'Coloring Class Labels, for {hidden_components} components')
    cbar = fig.colorbar(sc1, ax=ax1)
    cbar.ax.set_ylabel('Class Labels')

    plt.show()

    # plot each individual class
    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(30, 10), sharex=True, sharey=True)

    for i, ax in enumerate(axes.flatten()):
        mask = (class_labels == i)
        cluster_probs_total = cluster_probabilities(gmm, X_pca_131[mask, :]) 
        cluster_probs = cluster_probs_total[:, label_cluster_map[i]]

        cmap = plt.cm.get_cmap('viridis')
        norm = plt.Normalize(vmin=cluster_probs.min(), vmax=cluster_probs.max())
        sm = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
        colors = sm.to_rgba(cluster_probs)

        ax.scatter(X_pca_131[mask, 0], X_pca_131[mask, 1], c=colors)
        ax.set_title('Digit {}'.format(i))

    plt.suptitle("Visualization of Each Class of Digit, Colored By the Cluster Probability of the Best-fitting Cluster Label")
    fig.colorbar(sm, ax=axes.ravel().tolist(), shrink=0.75, label='Cluster Probabilities')
    plt.show()

Retrain with 5 hidden components

In [76]:
hidden_components = 5
train_map_plots(hidden_components, n_its = 1000)
Converged at iteration 179!
Label-Cluster Index Map:  {0: 2, 1: 4, 2: 3, 3: 3, 4: 1, 5: 2, 6: 3, 7: 1, 8: 2, 9: 3}

Retrain with 8 components

In [77]:
hidden_components = 8
train_map_plots(hidden_components, n_its = 1000)
Converged at iteration 305!
Label-Cluster Index Map:  {0: 5, 1: 6, 2: 3, 3: 2, 4: 3, 5: 5, 6: 3, 7: 1, 8: 7, 9: 3}

Analysis:¶

Usually, the number of hidden components should be be at least the number of classes (in this case 10), by assuming each digit follows a distinct distribution. But if elements of classes are 'similar', less hidden components could lead to a good clustering outcome. From 1.2.4, we see that the reconstructed images, even at m=10, are super-vague, not to mention we only used m=5 in this section. In this case, the recontructed digit images are really blurred (also shown below) and are likely to bring highly-similar patterns and information to the clustering machine.

  • Comparison for GMM with 5, 8, 10 hidden components
5 8 10
digits clustered to the same distribution: 479 to 3, 358 to 1 2369 to 3, 05 to 2, 47 to 1 2469 to 3, 0 to 55
numerous yellow points (Y/N) YNN, NNN YNNN, YN, NY NNNN, Y
  • In terms of distributions of clusters in the scatter plot, it can be seen obviously that for reduced hidden components to 5, the clusters are more well-separated, showing a reduced uncertainty.

  • One thing in common is that cluster probabilities are high for clusters densely overlapped. This is bacause points in highly dense areas mean samples are similar in these 2 dimensions. This will make clustering more difficult and lead to more uncertainty.

compliment to the analysis

In [78]:
# to show: reconstructed images for 5 principle components is highly blurred
test_constructed = test_original_pca @ eigenvectors_131 @ eigenvectors_131.T
plt.imshow(test_constructed[0].reshape(28,28))
plt.show()

Task 2: Clustering and graph-based analysis (35 marks)¶

In [79]:
# import and do the data exploration
gene_data = pd.read_csv("gene_expression_data.csv", decimal=",")  # as a pandas data frame
# expressions
gene_expression = gene_data[gene_data.columns[:-1]].astype(float)
# type (labels)
gene_type = gene_data[gene_data.columns[-1]]
display(gene_data.head(5))

print("The type of the gene_data, espressions and types:", type(gene_data))
print("The type of espressions:", type(gene_expression))
print("The type of types:", type(gene_type))
print("The shape of the data is:", gene_data.shape)
Gene 0 Gene 1 Gene 2 Gene 3 Gene 4 Gene 5 Gene 6 Gene 7 Gene 8 Gene 9 ... Gene 86 Gene 87 Gene 88 Gene 89 Gene 90 Gene 91 Gene 92 Gene 93 Gene 94 Type
0 9.79608829288 0.591870870063 0.591870870063 0.0 11.4205708246 13.4537593388 4.41184651859 5.4123344238 10.7716132653 10.2256653627 ... 5.97436869818 8.08651310808 12.7277503154 15.2057168981 6.43811649662 6.412576621 0.0 6.81472985199 13.6181445741 PRAD
1 10.0704698332 0.0 0.0 0.0 13.085671621 14.5318626848 10.4622977655 9.83292639746 13.5203117438 13.9680457391 ... 0.0 0.0 11.1972044043 12.9939325894 10.8007462415 10.7498107801 0.0 11.445609809 0.0 LUAD
2 8.97091978401 0.0 0.452595434703 0.0 8.26311893627 9.75490753946 8.96454880885 9.94811313134 8.6937726804 8.7761105697 ... 3.90715987355 5.32410132 11.4870662364 13.3805963452 6.65623606704 10.2097335917 0.0 7.7488301787 12.759975541 PRAD
3 8.52461614952 1.03941918313 0.434881719407 0.0 10.7985204031 12.2630197299 7.44069478818 8.06234301354 8.80208333007 9.23748723755 ... 4.29608291884 6.95974697548 12.974638649 14.8918121772 6.03072451189 7.31564773826 0.434881719407 7.11792356209 12.3532764196 PRAD
4 8.04723845046 0.0 0.0 0.360982241369 12.2830101953 14.033758513 8.71918001723 8.83147193285 8.46207277429 8.21120206054 ... 0.0 0.0 11.3372372064 13.3900614488 5.98959318494 8.35967050637 0.0 6.32754545866 0.0 BRCA

5 rows × 96 columns

The type of the gene_data, espressions and types: <class 'pandas.core.frame.DataFrame'>
The type of espressions: <class 'pandas.core.frame.DataFrame'>
The type of types: <class 'pandas.core.series.Series'>
The shape of the data is: (800, 96)
In [80]:
# standardize the data
def standardize(X):
    """Return a standardized dataset."""
    if type(X) != np.ndarray:
        X = X.to_numpy()

    mean = np.mean(X, 0)
    std = np.std(X, 0)
    return (X - mean)/std

2.1 Clustering (15 marks)¶

In this section, k-means clustering is implemeted for the gene expressions, where k is treated as a hyperparameter to tune. The highest Calinski-Harabasz index is used as a measure to find the optimal k and consistency of the clustering is assessed by homogeneity score.

Two things to notice:

  • In the initialization of labels in k-means algorithm, in order to avoid the situation where none of a sample is assigned to a certain cluster, we first randomly choose n-k samples from range 0 to k, and then add the labels with additional k labels, each of which is from a cluster. This guarantees at least one intially assigned sample for each cluster.

  • In computing $a_{ck}$ in homogeneity score, we should avoid it to be 0 as logarithm will be taken on it. For this case, simply skip it ie. add zero to the total sum, which will result into a slight underestimation in the true value of the total sum, compared to the result given by sklearn.

2.1.1 k-means clustering and find the optimal k¶

In [81]:
gene_expression = standardize(gene_expression)
In [82]:
def compute_centroids(X, k):
    """
    Return the centroid for each cluster according to assignments.
    Args:
        X: data
        k: number of clusters
    """

    if type(X) != np.ndarray:
        X = X.to_numpy()

    n_samples, n_features = X.shape
    
    # assign labels, each time with different initialization, but use a 'trick' to unsure at least 1 point in each cluster
    labels = np.random.randint(low=0, high=k, size=n_samples-k)
    labels = np.concatenate((labels, np.arange(k)))
    random.shuffle(labels)

    # initialization
    centroids = np.zeros((k, n_features))

    # compute the centroids of points with the same assigned label
    for i in range(k):
        centroids[i] = np.mean(X[labels==i], axis=0)

    return centroids, labels
In [83]:
# k-means algorithm, from coding book
def k_clustering(X, k, max_iter, message=False):
    """
    Return the updated centroids and labels for fixed k.
    Args:
    X: data set
    k: number of clusters
    max_iter: maximum number of iterations
    """
    if type(X) != np.ndarray:
        X = X.to_numpy()
    
    difference = 0
    new_labels = np.zeros(len(X))
    # initialize centroids: each time different outcomes
    centroids, labels = compute_centroids(X, k)

    for i in range(max_iter):
        if message==True:
            print('Iteration:', i)
        # distances: between data points and centroids
        distances = np.array([np.linalg.norm(X - c, axis=1) for c in centroids])
        # new_labels: computed by finding centroid with minimal distance
        new_labels = np.argmin(distances, axis=0)

        if (labels==new_labels).all():
            # labels unchanged
            labels = new_labels
            if message==True:
                print('Labels unchanged! Terminating k-means.')
            break
        else:
            # labels changed
            # difference: percentage of changed labels
            difference = np.mean(labels!=new_labels) 
            if message==True:
                print('%4f%% labels changed' % (difference * 100))
            labels = new_labels
            for c in range(k):
                # update centroids by taking the mean over associated data points
                if (labels == c).any():
                    centroids[c] = np.mean(X[labels==c], axis=0)

    return (centroids, labels)
In [84]:
# Calinski-Harabasz index
def bcsm(X, k, centroids, labels):
    """
    Return bcsm value.
    Args:
        X: gene_expression data set
        k: number of clusters
        centroids: the updated centroids after k_clustering
        labels: the updates labels after k_clustering
    """
    if type(X) != np.ndarray:
        X = X.to_numpy()

    # an array: number of points in each cluster
    n_i = np.array([sum(labels==i) for i in range(k)])
    # the centroid of all data points
    z_tot = np.mean(X, axis=0)
    # square distance between cluster centers and the total centroid
    dis_sqr = np.array([np.linalg.norm(centroids[i] - z_tot)**2 for i in range(k)])
    
    # return bcsm
    return sum(n_i * dis_sqr)

def wcsm(X, k, centroids, labels):
    """
    Return WCSM value.
    Args:
        X: gene_expression data set
        k: number of clusters
        centroids: the updated centroids after k_clustering
        labels: the updates labels after k_clustering
    """
    if type(X) != np.ndarray:
        X = X.to_numpy()

    wcsm_value = 0
    for i in range(k):
        wcsm_value += sum(np.array([np.linalg.norm(X[labels==i] - centroids[i])**2]))
    return wcsm_value

def CH_k(X, k, centroids, labels):
    """
    Return the wcsm quantity measure to assess the clustering of the data points.
    The greater, the better the classification.
    Args:
        X: gene_expression data set
        k: number of clusters
        centroids: the updated centroids after k_clustering
        labels: the updates labels after k_clustering
    """
    if type(X) != np.ndarray:
        X = X.to_numpy()

    bcsm_value = bcsm(X, k, centroids, labels)
    wcsm_value = wcsm(X, k, centroids, labels)

    return bcsm_value*(X.shape[0]-k) / ((k-1)*wcsm_value)
    
In [85]:
# define elbow function
def norm_within_cluster_dis(X, cluster_labels):
    """
    Return the w_c cost.
    Args:
        X: data set
        cluster_labels: updated cluster labels
        k: number of clusters.
    """
    w_c = 0
    for i in np.unique(cluster_labels):
        # extract the corresponding elmts
        cluster_elemts = X[cluster_labels==i, :]
        for j in cluster_elemts:
            for k in cluster_elemts:
                w_c += 0.5* np.linalg.norm(j-k)**2 / len(cluster_elemts)
    return w_c
In [86]:
# run for different k and 5 initializations for each k

np.random.seed(44)
k_range = 16  # k from 2 to k_range, as 'k-1' on the denominator
k_range_ch_index = np.array([])

for k_value in range(2, k_range):
    ch_k = np.array([])
    holding_l = []
    # 5 different initializations for each k
    for _ in range(5):
        up_centroids, up_labels = k_clustering(gene_expression, k_value, max_iter=70, message=True)
        # checked all converged with max_iter
        ch_index = CH_k(gene_expression, k_value, up_centroids, up_labels)
        ch_k = np.append(ch_k, ch_index)
    k_range_ch_index = np.append(k_range_ch_index, np.mean(ch_k))

# print the optimal k corresponding ch index
print(f"The optimal k and CH index are: {2 + np.argmax(k_range_ch_index)} and {np.max(k_range_ch_index)}")
Iteration: 0
46.250000% labels changed
Iteration: 1
1.500000% labels changed
Iteration: 2
0.500000% labels changed
Iteration: 3
Labels unchanged! Terminating k-means.
Iteration: 0
46.375000% labels changed
Iteration: 1
6.750000% labels changed
Iteration: 2
3.625000% labels changed
Iteration: 3
1.375000% labels changed
Iteration: 4
0.500000% labels changed
Iteration: 5
Labels unchanged! Terminating k-means.
Iteration: 0
46.750000% labels changed
Iteration: 1
19.625000% labels changed
Iteration: 2
3.750000% labels changed
Iteration: 3
1.250000% labels changed
Iteration: 4
0.375000% labels changed
Iteration: 5
Labels unchanged! Terminating k-means.
Iteration: 0
45.125000% labels changed
Iteration: 1
7.250000% labels changed
Iteration: 2
7.000000% labels changed
Iteration: 3
3.000000% labels changed
Iteration: 4
0.750000% labels changed
Iteration: 5
0.375000% labels changed
Iteration: 6
Labels unchanged! Terminating k-means.
Iteration: 0
46.500000% labels changed
Iteration: 1
9.000000% labels changed
Iteration: 2
2.375000% labels changed
Iteration: 3
1.375000% labels changed
Iteration: 4
1.000000% labels changed
Iteration: 5
1.125000% labels changed
Iteration: 6
1.125000% labels changed
Iteration: 7
0.625000% labels changed
Iteration: 8
0.375000% labels changed
Iteration: 9
Labels unchanged! Terminating k-means.
Iteration: 0
62.500000% labels changed
Iteration: 1
10.625000% labels changed
Iteration: 2
4.750000% labels changed
Iteration: 3
4.500000% labels changed
Iteration: 4
4.125000% labels changed
Iteration: 5
5.625000% labels changed
Iteration: 6
2.625000% labels changed
Iteration: 7
0.500000% labels changed
Iteration: 8
0.375000% labels changed
Iteration: 9
0.500000% labels changed
Iteration: 10
0.125000% labels changed
Iteration: 11
Labels unchanged! Terminating k-means.
Iteration: 0
62.750000% labels changed
Iteration: 1
24.875000% labels changed
Iteration: 2
13.000000% labels changed
Iteration: 3
5.375000% labels changed
Iteration: 4
1.875000% labels changed
Iteration: 5
1.500000% labels changed
Iteration: 6
1.250000% labels changed
Iteration: 7
0.500000% labels changed
Iteration: 8
0.625000% labels changed
Iteration: 9
0.625000% labels changed
Iteration: 10
0.375000% labels changed
Iteration: 11
0.125000% labels changed
Iteration: 12
0.375000% labels changed
Iteration: 13
0.125000% labels changed
Iteration: 14
Labels unchanged! Terminating k-means.
Iteration: 0
65.500000% labels changed
Iteration: 1
12.125000% labels changed
Iteration: 2
2.750000% labels changed
Iteration: 3
0.625000% labels changed
Iteration: 4
0.375000% labels changed
Iteration: 5
0.250000% labels changed
Iteration: 6
0.250000% labels changed
Iteration: 7
0.375000% labels changed
Iteration: 8
0.125000% labels changed
Iteration: 9
0.250000% labels changed
Iteration: 10
0.625000% labels changed
Iteration: 11
1.250000% labels changed
Iteration: 12
2.250000% labels changed
Iteration: 13
4.500000% labels changed
Iteration: 14
7.125000% labels changed
Iteration: 15
5.000000% labels changed
Iteration: 16
2.250000% labels changed
Iteration: 17
1.625000% labels changed
Iteration: 18
1.250000% labels changed
Iteration: 19
0.500000% labels changed
Iteration: 20
0.625000% labels changed
Iteration: 21
0.750000% labels changed
Iteration: 22
0.375000% labels changed
Iteration: 23
0.375000% labels changed
Iteration: 24
0.125000% labels changed
Iteration: 25
0.125000% labels changed
Iteration: 26
Labels unchanged! Terminating k-means.
Iteration: 0
63.125000% labels changed
Iteration: 1
25.750000% labels changed
Iteration: 2
7.375000% labels changed
Iteration: 3
2.000000% labels changed
Iteration: 4
1.500000% labels changed
Iteration: 5
1.875000% labels changed
Iteration: 6
1.625000% labels changed
Iteration: 7
1.250000% labels changed
Iteration: 8
0.875000% labels changed
Iteration: 9
0.625000% labels changed
Iteration: 10
0.875000% labels changed
Iteration: 11
0.375000% labels changed
Iteration: 12
0.375000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
0.125000% labels changed
Iteration: 15
Labels unchanged! Terminating k-means.
Iteration: 0
61.500000% labels changed
Iteration: 1
13.375000% labels changed
Iteration: 2
6.125000% labels changed
Iteration: 3
2.000000% labels changed
Iteration: 4
0.250000% labels changed
Iteration: 5
Labels unchanged! Terminating k-means.
Iteration: 0
70.750000% labels changed
Iteration: 1
19.750000% labels changed
Iteration: 2
5.625000% labels changed
Iteration: 3
2.875000% labels changed
Iteration: 4
3.875000% labels changed
Iteration: 5
2.875000% labels changed
Iteration: 6
3.125000% labels changed
Iteration: 7
3.750000% labels changed
Iteration: 8
2.625000% labels changed
Iteration: 9
1.000000% labels changed
Iteration: 10
1.375000% labels changed
Iteration: 11
1.875000% labels changed
Iteration: 12
1.375000% labels changed
Iteration: 13
0.500000% labels changed
Iteration: 14
0.750000% labels changed
Iteration: 15
0.125000% labels changed
Iteration: 16
0.250000% labels changed
Iteration: 17
0.125000% labels changed
Iteration: 18
0.250000% labels changed
Iteration: 19
0.375000% labels changed
Iteration: 20
0.125000% labels changed
Iteration: 21
0.250000% labels changed
Iteration: 22
0.250000% labels changed
Iteration: 23
0.125000% labels changed
Iteration: 24
0.125000% labels changed
Iteration: 25
0.375000% labels changed
Iteration: 26
0.250000% labels changed
Iteration: 27
0.375000% labels changed
Iteration: 28
0.750000% labels changed
Iteration: 29
0.750000% labels changed
Iteration: 30
1.125000% labels changed
Iteration: 31
2.000000% labels changed
Iteration: 32
4.250000% labels changed
Iteration: 33
7.000000% labels changed
Iteration: 34
6.250000% labels changed
Iteration: 35
4.250000% labels changed
Iteration: 36
4.500000% labels changed
Iteration: 37
4.500000% labels changed
Iteration: 38
2.375000% labels changed
Iteration: 39
1.625000% labels changed
Iteration: 40
1.375000% labels changed
Iteration: 41
0.750000% labels changed
Iteration: 42
1.125000% labels changed
Iteration: 43
0.375000% labels changed
Iteration: 44
Labels unchanged! Terminating k-means.
Iteration: 0
71.875000% labels changed
Iteration: 1
16.500000% labels changed
Iteration: 2
5.750000% labels changed
Iteration: 3
4.000000% labels changed
Iteration: 4
4.750000% labels changed
Iteration: 5
7.625000% labels changed
Iteration: 6
5.625000% labels changed
Iteration: 7
4.500000% labels changed
Iteration: 8
2.375000% labels changed
Iteration: 9
2.000000% labels changed
Iteration: 10
0.750000% labels changed
Iteration: 11
1.500000% labels changed
Iteration: 12
0.875000% labels changed
Iteration: 13
0.500000% labels changed
Iteration: 14
0.250000% labels changed
Iteration: 15
0.125000% labels changed
Iteration: 16
Labels unchanged! Terminating k-means.
Iteration: 0
69.125000% labels changed
Iteration: 1
27.375000% labels changed
Iteration: 2
14.000000% labels changed
Iteration: 3
6.500000% labels changed
Iteration: 4
2.250000% labels changed
Iteration: 5
1.000000% labels changed
Iteration: 6
0.750000% labels changed
Iteration: 7
0.625000% labels changed
Iteration: 8
0.750000% labels changed
Iteration: 9
0.875000% labels changed
Iteration: 10
1.000000% labels changed
Iteration: 11
0.375000% labels changed
Iteration: 12
0.125000% labels changed
Iteration: 13
0.125000% labels changed
Iteration: 14
Labels unchanged! Terminating k-means.
Iteration: 0
69.625000% labels changed
Iteration: 1
25.375000% labels changed
Iteration: 2
9.125000% labels changed
Iteration: 3
2.875000% labels changed
Iteration: 4
2.000000% labels changed
Iteration: 5
1.625000% labels changed
Iteration: 6
1.000000% labels changed
Iteration: 7
0.500000% labels changed
Iteration: 8
0.250000% labels changed
Iteration: 9
0.250000% labels changed
Iteration: 10
Labels unchanged! Terminating k-means.
Iteration: 0
71.500000% labels changed
Iteration: 1
23.250000% labels changed
Iteration: 2
14.875000% labels changed
Iteration: 3
8.250000% labels changed
Iteration: 4
5.875000% labels changed
Iteration: 5
4.375000% labels changed
Iteration: 6
2.750000% labels changed
Iteration: 7
2.125000% labels changed
Iteration: 8
1.625000% labels changed
Iteration: 9
1.125000% labels changed
Iteration: 10
0.500000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
Labels unchanged! Terminating k-means.
Iteration: 0
72.750000% labels changed
Iteration: 1
31.000000% labels changed
Iteration: 2
8.500000% labels changed
Iteration: 3
2.875000% labels changed
Iteration: 4
2.000000% labels changed
Iteration: 5
0.750000% labels changed
Iteration: 6
0.125000% labels changed
Iteration: 7
Labels unchanged! Terminating k-means.
Iteration: 0
75.000000% labels changed
Iteration: 1
19.250000% labels changed
Iteration: 2
7.125000% labels changed
Iteration: 3
1.750000% labels changed
Iteration: 4
0.625000% labels changed
Iteration: 5
0.375000% labels changed
Iteration: 6
0.375000% labels changed
Iteration: 7
0.375000% labels changed
Iteration: 8
0.375000% labels changed
Iteration: 9
0.625000% labels changed
Iteration: 10
0.250000% labels changed
Iteration: 11
0.125000% labels changed
Iteration: 12
Labels unchanged! Terminating k-means.
Iteration: 0
74.250000% labels changed
Iteration: 1
20.125000% labels changed
Iteration: 2
16.875000% labels changed
Iteration: 3
9.125000% labels changed
Iteration: 4
4.625000% labels changed
Iteration: 5
2.375000% labels changed
Iteration: 6
1.125000% labels changed
Iteration: 7
0.125000% labels changed
Iteration: 8
0.125000% labels changed
Iteration: 9
Labels unchanged! Terminating k-means.
Iteration: 0
75.125000% labels changed
Iteration: 1
27.750000% labels changed
Iteration: 2
11.750000% labels changed
Iteration: 3
5.375000% labels changed
Iteration: 4
5.250000% labels changed
Iteration: 5
5.500000% labels changed
Iteration: 6
4.000000% labels changed
Iteration: 7
4.875000% labels changed
Iteration: 8
3.375000% labels changed
Iteration: 9
2.875000% labels changed
Iteration: 10
1.750000% labels changed
Iteration: 11
1.125000% labels changed
Iteration: 12
1.250000% labels changed
Iteration: 13
0.750000% labels changed
Iteration: 14
0.750000% labels changed
Iteration: 15
0.250000% labels changed
Iteration: 16
0.125000% labels changed
Iteration: 17
Labels unchanged! Terminating k-means.
Iteration: 0
75.250000% labels changed
Iteration: 1
26.500000% labels changed
Iteration: 2
10.750000% labels changed
Iteration: 3
8.875000% labels changed
Iteration: 4
7.250000% labels changed
Iteration: 5
5.500000% labels changed
Iteration: 6
4.625000% labels changed
Iteration: 7
4.625000% labels changed
Iteration: 8
3.625000% labels changed
Iteration: 9
2.000000% labels changed
Iteration: 10
1.125000% labels changed
Iteration: 11
0.750000% labels changed
Iteration: 12
Labels unchanged! Terminating k-means.
Iteration: 0
78.250000% labels changed
Iteration: 1
25.875000% labels changed
Iteration: 2
12.250000% labels changed
Iteration: 3
6.000000% labels changed
Iteration: 4
4.750000% labels changed
Iteration: 5
3.875000% labels changed
Iteration: 6
2.750000% labels changed
Iteration: 7
1.375000% labels changed
Iteration: 8
1.375000% labels changed
Iteration: 9
1.625000% labels changed
Iteration: 10
1.750000% labels changed
Iteration: 11
1.750000% labels changed
Iteration: 12
1.625000% labels changed
Iteration: 13
1.375000% labels changed
Iteration: 14
1.625000% labels changed
Iteration: 15
1.250000% labels changed
Iteration: 16
1.375000% labels changed
Iteration: 17
1.500000% labels changed
Iteration: 18
1.625000% labels changed
Iteration: 19
0.750000% labels changed
Iteration: 20
0.250000% labels changed
Iteration: 21
0.250000% labels changed
Iteration: 22
0.250000% labels changed
Iteration: 23
0.125000% labels changed
Iteration: 24
0.250000% labels changed
Iteration: 25
0.375000% labels changed
Iteration: 26
0.500000% labels changed
Iteration: 27
0.500000% labels changed
Iteration: 28
0.375000% labels changed
Iteration: 29
0.250000% labels changed
Iteration: 30
0.250000% labels changed
Iteration: 31
0.250000% labels changed
Iteration: 32
Labels unchanged! Terminating k-means.
Iteration: 0
78.500000% labels changed
Iteration: 1
27.750000% labels changed
Iteration: 2
9.625000% labels changed
Iteration: 3
5.625000% labels changed
Iteration: 4
5.000000% labels changed
Iteration: 5
4.000000% labels changed
Iteration: 6
2.500000% labels changed
Iteration: 7
2.500000% labels changed
Iteration: 8
1.375000% labels changed
Iteration: 9
1.250000% labels changed
Iteration: 10
0.500000% labels changed
Iteration: 11
0.375000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
Labels unchanged! Terminating k-means.
Iteration: 0
78.250000% labels changed
Iteration: 1
22.500000% labels changed
Iteration: 2
12.500000% labels changed
Iteration: 3
6.250000% labels changed
Iteration: 4
4.250000% labels changed
Iteration: 5
3.000000% labels changed
Iteration: 6
2.000000% labels changed
Iteration: 7
1.125000% labels changed
Iteration: 8
0.375000% labels changed
Iteration: 9
0.750000% labels changed
Iteration: 10
0.125000% labels changed
Iteration: 11
0.125000% labels changed
Iteration: 12
Labels unchanged! Terminating k-means.
Iteration: 0
77.250000% labels changed
Iteration: 1
24.000000% labels changed
Iteration: 2
8.750000% labels changed
Iteration: 3
4.000000% labels changed
Iteration: 4
3.125000% labels changed
Iteration: 5
2.250000% labels changed
Iteration: 6
0.625000% labels changed
Iteration: 7
0.375000% labels changed
Iteration: 8
0.375000% labels changed
Iteration: 9
0.625000% labels changed
Iteration: 10
0.625000% labels changed
Iteration: 11
0.375000% labels changed
Iteration: 12
0.500000% labels changed
Iteration: 13
0.375000% labels changed
Iteration: 14
0.375000% labels changed
Iteration: 15
0.375000% labels changed
Iteration: 16
Labels unchanged! Terminating k-means.
Iteration: 0
76.875000% labels changed
Iteration: 1
23.375000% labels changed
Iteration: 2
7.375000% labels changed
Iteration: 3
4.875000% labels changed
Iteration: 4
4.250000% labels changed
Iteration: 5
2.750000% labels changed
Iteration: 6
2.625000% labels changed
Iteration: 7
2.625000% labels changed
Iteration: 8
2.500000% labels changed
Iteration: 9
1.875000% labels changed
Iteration: 10
1.000000% labels changed
Iteration: 11
0.875000% labels changed
Iteration: 12
1.000000% labels changed
Iteration: 13
1.000000% labels changed
Iteration: 14
1.125000% labels changed
Iteration: 15
2.125000% labels changed
Iteration: 16
1.500000% labels changed
Iteration: 17
1.750000% labels changed
Iteration: 18
2.000000% labels changed
Iteration: 19
1.750000% labels changed
Iteration: 20
1.625000% labels changed
Iteration: 21
1.625000% labels changed
Iteration: 22
1.750000% labels changed
Iteration: 23
2.250000% labels changed
Iteration: 24
1.625000% labels changed
Iteration: 25
0.500000% labels changed
Iteration: 26
0.375000% labels changed
Iteration: 27
0.250000% labels changed
Iteration: 28
Labels unchanged! Terminating k-means.
Iteration: 0
80.500000% labels changed
Iteration: 1
26.875000% labels changed
Iteration: 2
12.125000% labels changed
Iteration: 3
7.000000% labels changed
Iteration: 4
3.625000% labels changed
Iteration: 5
3.375000% labels changed
Iteration: 6
3.375000% labels changed
Iteration: 7
3.125000% labels changed
Iteration: 8
2.375000% labels changed
Iteration: 9
2.875000% labels changed
Iteration: 10
2.625000% labels changed
Iteration: 11
2.125000% labels changed
Iteration: 12
2.875000% labels changed
Iteration: 13
2.125000% labels changed
Iteration: 14
1.625000% labels changed
Iteration: 15
2.250000% labels changed
Iteration: 16
1.375000% labels changed
Iteration: 17
1.375000% labels changed
Iteration: 18
1.000000% labels changed
Iteration: 19
0.375000% labels changed
Iteration: 20
0.375000% labels changed
Iteration: 21
Labels unchanged! Terminating k-means.
Iteration: 0
80.250000% labels changed
Iteration: 1
13.250000% labels changed
Iteration: 2
7.500000% labels changed
Iteration: 3
5.625000% labels changed
Iteration: 4
4.125000% labels changed
Iteration: 5
2.750000% labels changed
Iteration: 6
2.375000% labels changed
Iteration: 7
1.750000% labels changed
Iteration: 8
1.750000% labels changed
Iteration: 9
0.875000% labels changed
Iteration: 10
0.375000% labels changed
Iteration: 11
Labels unchanged! Terminating k-means.
Iteration: 0
78.500000% labels changed
Iteration: 1
16.750000% labels changed
Iteration: 2
7.250000% labels changed
Iteration: 3
6.125000% labels changed
Iteration: 4
5.875000% labels changed
Iteration: 5
5.500000% labels changed
Iteration: 6
4.500000% labels changed
Iteration: 7
3.750000% labels changed
Iteration: 8
2.750000% labels changed
Iteration: 9
2.000000% labels changed
Iteration: 10
2.000000% labels changed
Iteration: 11
0.625000% labels changed
Iteration: 12
0.125000% labels changed
Iteration: 13
Labels unchanged! Terminating k-means.
Iteration: 0
80.250000% labels changed
Iteration: 1
27.750000% labels changed
Iteration: 2
12.125000% labels changed
Iteration: 3
8.875000% labels changed
Iteration: 4
5.500000% labels changed
Iteration: 5
6.000000% labels changed
Iteration: 6
8.250000% labels changed
Iteration: 7
7.250000% labels changed
Iteration: 8
5.000000% labels changed
Iteration: 9
3.750000% labels changed
Iteration: 10
2.750000% labels changed
Iteration: 11
2.625000% labels changed
Iteration: 12
1.000000% labels changed
Iteration: 13
1.375000% labels changed
Iteration: 14
0.875000% labels changed
Iteration: 15
0.500000% labels changed
Iteration: 16
Labels unchanged! Terminating k-means.
Iteration: 0
79.125000% labels changed
Iteration: 1
29.625000% labels changed
Iteration: 2
14.250000% labels changed
Iteration: 3
9.875000% labels changed
Iteration: 4
5.875000% labels changed
Iteration: 5
4.250000% labels changed
Iteration: 6
3.250000% labels changed
Iteration: 7
3.375000% labels changed
Iteration: 8
3.500000% labels changed
Iteration: 9
3.125000% labels changed
Iteration: 10
2.875000% labels changed
Iteration: 11
2.875000% labels changed
Iteration: 12
1.750000% labels changed
Iteration: 13
1.875000% labels changed
Iteration: 14
2.625000% labels changed
Iteration: 15
1.625000% labels changed
Iteration: 16
1.000000% labels changed
Iteration: 17
0.750000% labels changed
Iteration: 18
0.375000% labels changed
Iteration: 19
0.125000% labels changed
Iteration: 20
0.125000% labels changed
Iteration: 21
Labels unchanged! Terminating k-means.
Iteration: 0
84.000000% labels changed
Iteration: 1
34.500000% labels changed
Iteration: 2
10.750000% labels changed
Iteration: 3
5.125000% labels changed
Iteration: 4
3.500000% labels changed
Iteration: 5
3.125000% labels changed
Iteration: 6
3.000000% labels changed
Iteration: 7
1.875000% labels changed
Iteration: 8
1.125000% labels changed
Iteration: 9
0.750000% labels changed
Iteration: 10
0.250000% labels changed
Iteration: 11
0.125000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
Labels unchanged! Terminating k-means.
Iteration: 0
82.875000% labels changed
Iteration: 1
21.375000% labels changed
Iteration: 2
8.125000% labels changed
Iteration: 3
3.875000% labels changed
Iteration: 4
1.375000% labels changed
Iteration: 5
0.500000% labels changed
Iteration: 6
0.250000% labels changed
Iteration: 7
Labels unchanged! Terminating k-means.
Iteration: 0
82.875000% labels changed
Iteration: 1
24.375000% labels changed
Iteration: 2
9.000000% labels changed
Iteration: 3
3.250000% labels changed
Iteration: 4
3.125000% labels changed
Iteration: 5
3.000000% labels changed
Iteration: 6
2.250000% labels changed
Iteration: 7
2.000000% labels changed
Iteration: 8
1.500000% labels changed
Iteration: 9
0.750000% labels changed
Iteration: 10
1.000000% labels changed
Iteration: 11
0.750000% labels changed
Iteration: 12
0.875000% labels changed
Iteration: 13
0.750000% labels changed
Iteration: 14
0.875000% labels changed
Iteration: 15
0.500000% labels changed
Iteration: 16
0.750000% labels changed
Iteration: 17
0.875000% labels changed
Iteration: 18
0.500000% labels changed
Iteration: 19
0.500000% labels changed
Iteration: 20
0.625000% labels changed
Iteration: 21
0.750000% labels changed
Iteration: 22
0.250000% labels changed
Iteration: 23
Labels unchanged! Terminating k-means.
Iteration: 0
81.750000% labels changed
Iteration: 1
25.125000% labels changed
Iteration: 2
14.750000% labels changed
Iteration: 3
10.375000% labels changed
Iteration: 4
6.875000% labels changed
Iteration: 5
4.500000% labels changed
Iteration: 6
2.875000% labels changed
Iteration: 7
2.250000% labels changed
Iteration: 8
1.750000% labels changed
Iteration: 9
1.875000% labels changed
Iteration: 10
1.625000% labels changed
Iteration: 11
1.000000% labels changed
Iteration: 12
1.000000% labels changed
Iteration: 13
0.750000% labels changed
Iteration: 14
0.750000% labels changed
Iteration: 15
0.500000% labels changed
Iteration: 16
0.625000% labels changed
Iteration: 17
0.750000% labels changed
Iteration: 18
0.625000% labels changed
Iteration: 19
0.375000% labels changed
Iteration: 20
0.875000% labels changed
Iteration: 21
0.250000% labels changed
Iteration: 22
0.375000% labels changed
Iteration: 23
0.375000% labels changed
Iteration: 24
0.375000% labels changed
Iteration: 25
0.500000% labels changed
Iteration: 26
0.250000% labels changed
Iteration: 27
0.125000% labels changed
Iteration: 28
0.125000% labels changed
Iteration: 29
0.250000% labels changed
Iteration: 30
0.625000% labels changed
Iteration: 31
0.500000% labels changed
Iteration: 32
0.625000% labels changed
Iteration: 33
0.500000% labels changed
Iteration: 34
0.250000% labels changed
Iteration: 35
0.250000% labels changed
Iteration: 36
Labels unchanged! Terminating k-means.
Iteration: 0
82.125000% labels changed
Iteration: 1
29.750000% labels changed
Iteration: 2
15.500000% labels changed
Iteration: 3
7.375000% labels changed
Iteration: 4
7.250000% labels changed
Iteration: 5
6.875000% labels changed
Iteration: 6
5.250000% labels changed
Iteration: 7
4.625000% labels changed
Iteration: 8
2.750000% labels changed
Iteration: 9
1.250000% labels changed
Iteration: 10
0.125000% labels changed
Iteration: 11
Labels unchanged! Terminating k-means.
Iteration: 0
82.000000% labels changed
Iteration: 1
31.625000% labels changed
Iteration: 2
10.125000% labels changed
Iteration: 3
3.375000% labels changed
Iteration: 4
3.000000% labels changed
Iteration: 5
1.125000% labels changed
Iteration: 6
1.250000% labels changed
Iteration: 7
1.125000% labels changed
Iteration: 8
0.625000% labels changed
Iteration: 9
0.250000% labels changed
Iteration: 10
0.125000% labels changed
Iteration: 11
Labels unchanged! Terminating k-means.
Iteration: 0
84.500000% labels changed
Iteration: 1
23.625000% labels changed
Iteration: 2
8.375000% labels changed
Iteration: 3
6.000000% labels changed
Iteration: 4
6.000000% labels changed
Iteration: 5
4.125000% labels changed
Iteration: 6
3.375000% labels changed
Iteration: 7
2.750000% labels changed
Iteration: 8
2.750000% labels changed
Iteration: 9
2.375000% labels changed
Iteration: 10
1.750000% labels changed
Iteration: 11
1.500000% labels changed
Iteration: 12
1.500000% labels changed
Iteration: 13
0.875000% labels changed
Iteration: 14
0.875000% labels changed
Iteration: 15
0.500000% labels changed
Iteration: 16
0.750000% labels changed
Iteration: 17
0.750000% labels changed
Iteration: 18
0.125000% labels changed
Iteration: 19
0.250000% labels changed
Iteration: 20
Labels unchanged! Terminating k-means.
Iteration: 0
85.375000% labels changed
Iteration: 1
28.250000% labels changed
Iteration: 2
9.750000% labels changed
Iteration: 3
3.625000% labels changed
Iteration: 4
2.500000% labels changed
Iteration: 5
1.250000% labels changed
Iteration: 6
0.250000% labels changed
Iteration: 7
0.375000% labels changed
Iteration: 8
0.500000% labels changed
Iteration: 9
0.625000% labels changed
Iteration: 10
0.750000% labels changed
Iteration: 11
0.750000% labels changed
Iteration: 12
2.000000% labels changed
Iteration: 13
1.500000% labels changed
Iteration: 14
1.000000% labels changed
Iteration: 15
0.750000% labels changed
Iteration: 16
0.625000% labels changed
Iteration: 17
0.500000% labels changed
Iteration: 18
0.375000% labels changed
Iteration: 19
0.500000% labels changed
Iteration: 20
0.250000% labels changed
Iteration: 21
Labels unchanged! Terminating k-means.
Iteration: 0
83.750000% labels changed
Iteration: 1
25.000000% labels changed
Iteration: 2
13.125000% labels changed
Iteration: 3
7.125000% labels changed
Iteration: 4
4.375000% labels changed
Iteration: 5
3.250000% labels changed
Iteration: 6
2.250000% labels changed
Iteration: 7
1.500000% labels changed
Iteration: 8
0.750000% labels changed
Iteration: 9
0.750000% labels changed
Iteration: 10
0.250000% labels changed
Iteration: 11
0.375000% labels changed
Iteration: 12
0.375000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
0.500000% labels changed
Iteration: 15
0.375000% labels changed
Iteration: 16
0.500000% labels changed
Iteration: 17
0.375000% labels changed
Iteration: 18
0.500000% labels changed
Iteration: 19
0.375000% labels changed
Iteration: 20
0.750000% labels changed
Iteration: 21
0.250000% labels changed
Iteration: 22
0.500000% labels changed
Iteration: 23
0.625000% labels changed
Iteration: 24
0.875000% labels changed
Iteration: 25
1.000000% labels changed
Iteration: 26
1.000000% labels changed
Iteration: 27
1.125000% labels changed
Iteration: 28
0.500000% labels changed
Iteration: 29
0.125000% labels changed
Iteration: 30
0.125000% labels changed
Iteration: 31
Labels unchanged! Terminating k-means.
Iteration: 0
82.625000% labels changed
Iteration: 1
31.250000% labels changed
Iteration: 2
15.750000% labels changed
Iteration: 3
8.125000% labels changed
Iteration: 4
3.875000% labels changed
Iteration: 5
2.500000% labels changed
Iteration: 6
1.875000% labels changed
Iteration: 7
1.625000% labels changed
Iteration: 8
0.625000% labels changed
Iteration: 9
0.500000% labels changed
Iteration: 10
0.125000% labels changed
Iteration: 11
Labels unchanged! Terminating k-means.
Iteration: 0
84.000000% labels changed
Iteration: 1
36.375000% labels changed
Iteration: 2
15.375000% labels changed
Iteration: 3
8.750000% labels changed
Iteration: 4
4.500000% labels changed
Iteration: 5
2.375000% labels changed
Iteration: 6
2.125000% labels changed
Iteration: 7
1.750000% labels changed
Iteration: 8
1.625000% labels changed
Iteration: 9
1.250000% labels changed
Iteration: 10
1.125000% labels changed
Iteration: 11
1.125000% labels changed
Iteration: 12
1.625000% labels changed
Iteration: 13
1.500000% labels changed
Iteration: 14
1.125000% labels changed
Iteration: 15
0.875000% labels changed
Iteration: 16
0.750000% labels changed
Iteration: 17
0.250000% labels changed
Iteration: 18
Labels unchanged! Terminating k-means.
Iteration: 0
85.000000% labels changed
Iteration: 1
29.500000% labels changed
Iteration: 2
10.250000% labels changed
Iteration: 3
7.375000% labels changed
Iteration: 4
6.625000% labels changed
Iteration: 5
5.625000% labels changed
Iteration: 6
4.250000% labels changed
Iteration: 7
3.000000% labels changed
Iteration: 8
2.125000% labels changed
Iteration: 9
2.000000% labels changed
Iteration: 10
0.750000% labels changed
Iteration: 11
0.500000% labels changed
Iteration: 12
Labels unchanged! Terminating k-means.
Iteration: 0
84.375000% labels changed
Iteration: 1
31.875000% labels changed
Iteration: 2
11.500000% labels changed
Iteration: 3
5.375000% labels changed
Iteration: 4
3.500000% labels changed
Iteration: 5
1.625000% labels changed
Iteration: 6
1.000000% labels changed
Iteration: 7
1.125000% labels changed
Iteration: 8
0.875000% labels changed
Iteration: 9
1.500000% labels changed
Iteration: 10
0.750000% labels changed
Iteration: 11
0.625000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
0.125000% labels changed
Iteration: 14
0.125000% labels changed
Iteration: 15
Labels unchanged! Terminating k-means.
Iteration: 0
84.500000% labels changed
Iteration: 1
28.625000% labels changed
Iteration: 2
13.375000% labels changed
Iteration: 3
6.250000% labels changed
Iteration: 4
4.875000% labels changed
Iteration: 5
4.750000% labels changed
Iteration: 6
3.250000% labels changed
Iteration: 7
3.125000% labels changed
Iteration: 8
3.125000% labels changed
Iteration: 9
2.875000% labels changed
Iteration: 10
2.250000% labels changed
Iteration: 11
2.375000% labels changed
Iteration: 12
1.750000% labels changed
Iteration: 13
2.250000% labels changed
Iteration: 14
2.000000% labels changed
Iteration: 15
1.625000% labels changed
Iteration: 16
1.250000% labels changed
Iteration: 17
1.375000% labels changed
Iteration: 18
1.375000% labels changed
Iteration: 19
0.625000% labels changed
Iteration: 20
0.500000% labels changed
Iteration: 21
0.250000% labels changed
Iteration: 22
0.375000% labels changed
Iteration: 23
0.500000% labels changed
Iteration: 24
0.625000% labels changed
Iteration: 25
0.750000% labels changed
Iteration: 26
0.125000% labels changed
Iteration: 27
0.125000% labels changed
Iteration: 28
0.250000% labels changed
Iteration: 29
0.125000% labels changed
Iteration: 30
0.250000% labels changed
Iteration: 31
0.250000% labels changed
Iteration: 32
Labels unchanged! Terminating k-means.
Iteration: 0
82.000000% labels changed
Iteration: 1
28.750000% labels changed
Iteration: 2
16.750000% labels changed
Iteration: 3
8.875000% labels changed
Iteration: 4
3.875000% labels changed
Iteration: 5
2.125000% labels changed
Iteration: 6
1.750000% labels changed
Iteration: 7
1.500000% labels changed
Iteration: 8
0.375000% labels changed
Iteration: 9
0.125000% labels changed
Iteration: 10
0.250000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
Labels unchanged! Terminating k-means.
Iteration: 0
85.750000% labels changed
Iteration: 1
20.375000% labels changed
Iteration: 2
10.500000% labels changed
Iteration: 3
6.625000% labels changed
Iteration: 4
4.125000% labels changed
Iteration: 5
2.375000% labels changed
Iteration: 6
2.000000% labels changed
Iteration: 7
1.750000% labels changed
Iteration: 8
1.375000% labels changed
Iteration: 9
1.500000% labels changed
Iteration: 10
1.000000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
0.375000% labels changed
Iteration: 14
0.500000% labels changed
Iteration: 15
0.375000% labels changed
Iteration: 16
0.375000% labels changed
Iteration: 17
0.250000% labels changed
Iteration: 18
0.250000% labels changed
Iteration: 19
0.125000% labels changed
Iteration: 20
Labels unchanged! Terminating k-means.
Iteration: 0
85.250000% labels changed
Iteration: 1
23.875000% labels changed
Iteration: 2
9.625000% labels changed
Iteration: 3
4.125000% labels changed
Iteration: 4
2.500000% labels changed
Iteration: 5
2.000000% labels changed
Iteration: 6
1.250000% labels changed
Iteration: 7
0.375000% labels changed
Iteration: 8
0.125000% labels changed
Iteration: 9
0.500000% labels changed
Iteration: 10
0.375000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
0.500000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
0.250000% labels changed
Iteration: 15
0.250000% labels changed
Iteration: 16
Labels unchanged! Terminating k-means.
Iteration: 0
85.250000% labels changed
Iteration: 1
23.125000% labels changed
Iteration: 2
10.875000% labels changed
Iteration: 3
4.625000% labels changed
Iteration: 4
3.750000% labels changed
Iteration: 5
3.625000% labels changed
Iteration: 6
3.000000% labels changed
Iteration: 7
2.250000% labels changed
Iteration: 8
2.750000% labels changed
Iteration: 9
2.750000% labels changed
Iteration: 10
3.125000% labels changed
Iteration: 11
1.875000% labels changed
Iteration: 12
1.750000% labels changed
Iteration: 13
1.750000% labels changed
Iteration: 14
0.875000% labels changed
Iteration: 15
0.875000% labels changed
Iteration: 16
0.250000% labels changed
Iteration: 17
0.250000% labels changed
Iteration: 18
0.500000% labels changed
Iteration: 19
0.875000% labels changed
Iteration: 20
0.375000% labels changed
Iteration: 21
0.250000% labels changed
Iteration: 22
0.250000% labels changed
Iteration: 23
Labels unchanged! Terminating k-means.
Iteration: 0
86.000000% labels changed
Iteration: 1
32.125000% labels changed
Iteration: 2
14.375000% labels changed
Iteration: 3
7.250000% labels changed
Iteration: 4
3.750000% labels changed
Iteration: 5
3.125000% labels changed
Iteration: 6
3.750000% labels changed
Iteration: 7
3.125000% labels changed
Iteration: 8
2.250000% labels changed
Iteration: 9
1.000000% labels changed
Iteration: 10
0.625000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
Labels unchanged! Terminating k-means.
Iteration: 0
85.750000% labels changed
Iteration: 1
18.875000% labels changed
Iteration: 2
11.375000% labels changed
Iteration: 3
8.875000% labels changed
Iteration: 4
5.500000% labels changed
Iteration: 5
3.875000% labels changed
Iteration: 6
3.125000% labels changed
Iteration: 7
4.375000% labels changed
Iteration: 8
3.375000% labels changed
Iteration: 9
2.875000% labels changed
Iteration: 10
2.375000% labels changed
Iteration: 11
1.500000% labels changed
Iteration: 12
1.000000% labels changed
Iteration: 13
0.750000% labels changed
Iteration: 14
0.375000% labels changed
Iteration: 15
0.125000% labels changed
Iteration: 16
Labels unchanged! Terminating k-means.
Iteration: 0
89.625000% labels changed
Iteration: 1
37.500000% labels changed
Iteration: 2
16.875000% labels changed
Iteration: 3
9.625000% labels changed
Iteration: 4
5.250000% labels changed
Iteration: 5
3.125000% labels changed
Iteration: 6
2.000000% labels changed
Iteration: 7
0.750000% labels changed
Iteration: 8
0.625000% labels changed
Iteration: 9
0.750000% labels changed
Iteration: 10
0.250000% labels changed
Iteration: 11
0.500000% labels changed
Iteration: 12
0.500000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
Labels unchanged! Terminating k-means.
Iteration: 0
84.875000% labels changed
Iteration: 1
40.375000% labels changed
Iteration: 2
8.750000% labels changed
Iteration: 3
3.875000% labels changed
Iteration: 4
2.500000% labels changed
Iteration: 5
1.750000% labels changed
Iteration: 6
1.625000% labels changed
Iteration: 7
1.250000% labels changed
Iteration: 8
1.250000% labels changed
Iteration: 9
1.250000% labels changed
Iteration: 10
1.250000% labels changed
Iteration: 11
1.250000% labels changed
Iteration: 12
0.625000% labels changed
Iteration: 13
0.750000% labels changed
Iteration: 14
1.000000% labels changed
Iteration: 15
1.125000% labels changed
Iteration: 16
0.500000% labels changed
Iteration: 17
0.500000% labels changed
Iteration: 18
0.500000% labels changed
Iteration: 19
0.750000% labels changed
Iteration: 20
0.500000% labels changed
Iteration: 21
0.500000% labels changed
Iteration: 22
0.375000% labels changed
Iteration: 23
Labels unchanged! Terminating k-means.
Iteration: 0
87.250000% labels changed
Iteration: 1
28.500000% labels changed
Iteration: 2
8.375000% labels changed
Iteration: 3
5.125000% labels changed
Iteration: 4
5.375000% labels changed
Iteration: 5
5.125000% labels changed
Iteration: 6
3.500000% labels changed
Iteration: 7
1.875000% labels changed
Iteration: 8
1.500000% labels changed
Iteration: 9
1.000000% labels changed
Iteration: 10
0.500000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
0.375000% labels changed
Iteration: 13
0.125000% labels changed
Iteration: 14
Labels unchanged! Terminating k-means.
Iteration: 0
85.000000% labels changed
Iteration: 1
25.375000% labels changed
Iteration: 2
10.875000% labels changed
Iteration: 3
7.625000% labels changed
Iteration: 4
5.750000% labels changed
Iteration: 5
3.375000% labels changed
Iteration: 6
2.000000% labels changed
Iteration: 7
1.000000% labels changed
Iteration: 8
0.750000% labels changed
Iteration: 9
0.500000% labels changed
Iteration: 10
0.250000% labels changed
Iteration: 11
0.125000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
0.250000% labels changed
Iteration: 15
0.375000% labels changed
Iteration: 16
0.375000% labels changed
Iteration: 17
0.125000% labels changed
Iteration: 18
Labels unchanged! Terminating k-means.
Iteration: 0
86.625000% labels changed
Iteration: 1
30.250000% labels changed
Iteration: 2
10.875000% labels changed
Iteration: 3
6.500000% labels changed
Iteration: 4
4.000000% labels changed
Iteration: 5
2.750000% labels changed
Iteration: 6
2.500000% labels changed
Iteration: 7
2.375000% labels changed
Iteration: 8
2.375000% labels changed
Iteration: 9
2.000000% labels changed
Iteration: 10
0.750000% labels changed
Iteration: 11
0.500000% labels changed
Iteration: 12
Labels unchanged! Terminating k-means.
Iteration: 0
87.625000% labels changed
Iteration: 1
26.125000% labels changed
Iteration: 2
11.500000% labels changed
Iteration: 3
6.625000% labels changed
Iteration: 4
3.250000% labels changed
Iteration: 5
1.875000% labels changed
Iteration: 6
0.875000% labels changed
Iteration: 7
0.875000% labels changed
Iteration: 8
0.875000% labels changed
Iteration: 9
0.625000% labels changed
Iteration: 10
0.875000% labels changed
Iteration: 11
0.750000% labels changed
Iteration: 12
0.375000% labels changed
Iteration: 13
0.375000% labels changed
Iteration: 14
Labels unchanged! Terminating k-means.
Iteration: 0
86.625000% labels changed
Iteration: 1
31.000000% labels changed
Iteration: 2
12.000000% labels changed
Iteration: 3
5.875000% labels changed
Iteration: 4
3.875000% labels changed
Iteration: 5
2.125000% labels changed
Iteration: 6
0.875000% labels changed
Iteration: 7
0.750000% labels changed
Iteration: 8
0.875000% labels changed
Iteration: 9
1.125000% labels changed
Iteration: 10
0.875000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
0.500000% labels changed
Iteration: 14
0.375000% labels changed
Iteration: 15
0.250000% labels changed
Iteration: 16
0.375000% labels changed
Iteration: 17
0.250000% labels changed
Iteration: 18
0.250000% labels changed
Iteration: 19
Labels unchanged! Terminating k-means.
Iteration: 0
89.250000% labels changed
Iteration: 1
33.500000% labels changed
Iteration: 2
12.750000% labels changed
Iteration: 3
9.125000% labels changed
Iteration: 4
5.250000% labels changed
Iteration: 5
3.500000% labels changed
Iteration: 6
2.625000% labels changed
Iteration: 7
1.625000% labels changed
Iteration: 8
1.750000% labels changed
Iteration: 9
1.375000% labels changed
Iteration: 10
1.500000% labels changed
Iteration: 11
1.250000% labels changed
Iteration: 12
0.750000% labels changed
Iteration: 13
1.000000% labels changed
Iteration: 14
0.750000% labels changed
Iteration: 15
0.625000% labels changed
Iteration: 16
0.500000% labels changed
Iteration: 17
0.375000% labels changed
Iteration: 18
0.375000% labels changed
Iteration: 19
Labels unchanged! Terminating k-means.
Iteration: 0
87.250000% labels changed
Iteration: 1
26.250000% labels changed
Iteration: 2
14.125000% labels changed
Iteration: 3
9.750000% labels changed
Iteration: 4
6.500000% labels changed
Iteration: 5
4.250000% labels changed
Iteration: 6
2.250000% labels changed
Iteration: 7
1.875000% labels changed
Iteration: 8
1.625000% labels changed
Iteration: 9
0.875000% labels changed
Iteration: 10
1.000000% labels changed
Iteration: 11
0.500000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
0.500000% labels changed
Iteration: 15
0.875000% labels changed
Iteration: 16
0.625000% labels changed
Iteration: 17
0.250000% labels changed
Iteration: 18
0.125000% labels changed
Iteration: 19
Labels unchanged! Terminating k-means.
Iteration: 0
88.375000% labels changed
Iteration: 1
30.750000% labels changed
Iteration: 2
11.500000% labels changed
Iteration: 3
8.625000% labels changed
Iteration: 4
5.000000% labels changed
Iteration: 5
3.875000% labels changed
Iteration: 6
3.625000% labels changed
Iteration: 7
2.875000% labels changed
Iteration: 8
2.250000% labels changed
Iteration: 9
1.250000% labels changed
Iteration: 10
0.875000% labels changed
Iteration: 11
0.375000% labels changed
Iteration: 12
0.250000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
0.250000% labels changed
Iteration: 15
0.250000% labels changed
Iteration: 16
0.125000% labels changed
Iteration: 17
Labels unchanged! Terminating k-means.
Iteration: 0
87.875000% labels changed
Iteration: 1
22.625000% labels changed
Iteration: 2
9.625000% labels changed
Iteration: 3
6.125000% labels changed
Iteration: 4
5.875000% labels changed
Iteration: 5
6.500000% labels changed
Iteration: 6
4.625000% labels changed
Iteration: 7
2.000000% labels changed
Iteration: 8
1.500000% labels changed
Iteration: 9
0.375000% labels changed
Iteration: 10
0.375000% labels changed
Iteration: 11
0.500000% labels changed
Iteration: 12
0.500000% labels changed
Iteration: 13
0.500000% labels changed
Iteration: 14
0.875000% labels changed
Iteration: 15
0.625000% labels changed
Iteration: 16
0.875000% labels changed
Iteration: 17
1.250000% labels changed
Iteration: 18
1.000000% labels changed
Iteration: 19
0.500000% labels changed
Iteration: 20
0.250000% labels changed
Iteration: 21
0.250000% labels changed
Iteration: 22
0.375000% labels changed
Iteration: 23
0.250000% labels changed
Iteration: 24
0.375000% labels changed
Iteration: 25
0.125000% labels changed
Iteration: 26
0.125000% labels changed
Iteration: 27
0.125000% labels changed
Iteration: 28
0.250000% labels changed
Iteration: 29
0.250000% labels changed
Iteration: 30
0.500000% labels changed
Iteration: 31
0.375000% labels changed
Iteration: 32
0.250000% labels changed
Iteration: 33
0.375000% labels changed
Iteration: 34
0.250000% labels changed
Iteration: 35
0.625000% labels changed
Iteration: 36
0.125000% labels changed
Iteration: 37
0.125000% labels changed
Iteration: 38
Labels unchanged! Terminating k-means.
Iteration: 0
87.625000% labels changed
Iteration: 1
29.125000% labels changed
Iteration: 2
13.125000% labels changed
Iteration: 3
7.750000% labels changed
Iteration: 4
5.000000% labels changed
Iteration: 5
2.875000% labels changed
Iteration: 6
2.125000% labels changed
Iteration: 7
1.625000% labels changed
Iteration: 8
1.125000% labels changed
Iteration: 9
0.750000% labels changed
Iteration: 10
0.250000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
0.125000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
0.250000% labels changed
Iteration: 15
Labels unchanged! Terminating k-means.
Iteration: 0
86.625000% labels changed
Iteration: 1
25.750000% labels changed
Iteration: 2
13.875000% labels changed
Iteration: 3
8.000000% labels changed
Iteration: 4
5.000000% labels changed
Iteration: 5
3.000000% labels changed
Iteration: 6
2.750000% labels changed
Iteration: 7
1.750000% labels changed
Iteration: 8
1.750000% labels changed
Iteration: 9
1.125000% labels changed
Iteration: 10
0.750000% labels changed
Iteration: 11
0.625000% labels changed
Iteration: 12
0.500000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
0.375000% labels changed
Iteration: 15
0.250000% labels changed
Iteration: 16
0.250000% labels changed
Iteration: 17
0.250000% labels changed
Iteration: 18
0.250000% labels changed
Iteration: 19
0.125000% labels changed
Iteration: 20
Labels unchanged! Terminating k-means.
Iteration: 0
89.125000% labels changed
Iteration: 1
26.375000% labels changed
Iteration: 2
13.875000% labels changed
Iteration: 3
6.500000% labels changed
Iteration: 4
3.625000% labels changed
Iteration: 5
2.375000% labels changed
Iteration: 6
2.000000% labels changed
Iteration: 7
1.125000% labels changed
Iteration: 8
0.625000% labels changed
Iteration: 9
0.250000% labels changed
Iteration: 10
Labels unchanged! Terminating k-means.
Iteration: 0
86.750000% labels changed
Iteration: 1
30.125000% labels changed
Iteration: 2
10.625000% labels changed
Iteration: 3
6.625000% labels changed
Iteration: 4
4.875000% labels changed
Iteration: 5
3.375000% labels changed
Iteration: 6
3.000000% labels changed
Iteration: 7
4.375000% labels changed
Iteration: 8
4.750000% labels changed
Iteration: 9
4.250000% labels changed
Iteration: 10
3.375000% labels changed
Iteration: 11
3.000000% labels changed
Iteration: 12
1.625000% labels changed
Iteration: 13
0.250000% labels changed
Iteration: 14
Labels unchanged! Terminating k-means.
Iteration: 0
89.625000% labels changed
Iteration: 1
36.125000% labels changed
Iteration: 2
12.000000% labels changed
Iteration: 3
7.125000% labels changed
Iteration: 4
6.250000% labels changed
Iteration: 5
3.750000% labels changed
Iteration: 6
2.125000% labels changed
Iteration: 7
1.750000% labels changed
Iteration: 8
1.500000% labels changed
Iteration: 9
1.125000% labels changed
Iteration: 10
1.250000% labels changed
Iteration: 11
1.375000% labels changed
Iteration: 12
1.875000% labels changed
Iteration: 13
1.125000% labels changed
Iteration: 14
1.000000% labels changed
Iteration: 15
0.500000% labels changed
Iteration: 16
0.125000% labels changed
Iteration: 17
Labels unchanged! Terminating k-means.
Iteration: 0
88.000000% labels changed
Iteration: 1
34.375000% labels changed
Iteration: 2
14.625000% labels changed
Iteration: 3
7.250000% labels changed
Iteration: 4
4.500000% labels changed
Iteration: 5
4.000000% labels changed
Iteration: 6
2.875000% labels changed
Iteration: 7
2.125000% labels changed
Iteration: 8
1.750000% labels changed
Iteration: 9
1.375000% labels changed
Iteration: 10
0.250000% labels changed
Iteration: 11
0.500000% labels changed
Iteration: 12
0.125000% labels changed
Iteration: 13
0.125000% labels changed
Iteration: 14
Labels unchanged! Terminating k-means.
Iteration: 0
89.250000% labels changed
Iteration: 1
28.125000% labels changed
Iteration: 2
12.000000% labels changed
Iteration: 3
5.750000% labels changed
Iteration: 4
2.625000% labels changed
Iteration: 5
3.000000% labels changed
Iteration: 6
2.500000% labels changed
Iteration: 7
0.500000% labels changed
Iteration: 8
0.250000% labels changed
Iteration: 9
0.125000% labels changed
Iteration: 10
Labels unchanged! Terminating k-means.
Iteration: 0
90.000000% labels changed
Iteration: 1
28.875000% labels changed
Iteration: 2
13.250000% labels changed
Iteration: 3
8.250000% labels changed
Iteration: 4
5.625000% labels changed
Iteration: 5
4.125000% labels changed
Iteration: 6
3.625000% labels changed
Iteration: 7
1.375000% labels changed
Iteration: 8
0.750000% labels changed
Iteration: 9
Labels unchanged! Terminating k-means.
Iteration: 0
87.625000% labels changed
Iteration: 1
26.250000% labels changed
Iteration: 2
9.375000% labels changed
Iteration: 3
6.375000% labels changed
Iteration: 4
4.375000% labels changed
Iteration: 5
3.250000% labels changed
Iteration: 6
2.375000% labels changed
Iteration: 7
2.000000% labels changed
Iteration: 8
1.000000% labels changed
Iteration: 9
0.750000% labels changed
Iteration: 10
0.500000% labels changed
Iteration: 11
0.750000% labels changed
Iteration: 12
1.250000% labels changed
Iteration: 13
0.875000% labels changed
Iteration: 14
1.125000% labels changed
Iteration: 15
1.375000% labels changed
Iteration: 16
1.500000% labels changed
Iteration: 17
1.625000% labels changed
Iteration: 18
0.875000% labels changed
Iteration: 19
0.625000% labels changed
Iteration: 20
0.375000% labels changed
Iteration: 21
0.250000% labels changed
Iteration: 22
0.125000% labels changed
Iteration: 23
0.250000% labels changed
Iteration: 24
Labels unchanged! Terminating k-means.
The optimal k and CH index are: 4 and 233.02152942588236
In [87]:
# plotting
plt.plot(range(2, k_range), k_range_ch_index, color="blue")
plt.scatter(range(2, k_range), k_range_ch_index, color="blue")
plt.scatter(2 + np.argmax(k_range_ch_index), np.max(k_range_ch_index), color="red", label="optimal point")
plt.xlabel("The number of clusters k")
plt.ylabel("The value of Calinski-Harabasz index")
plt.title("Plot of Average Calinski-Harabasz Index Against Cluster Number k")
plt.legend(loc = 'upper right')
plt.show()
In [88]:
# find the size of clusters found by k-means with the optimal k

opt_k = 2 + np.argmax(k_range_ch_index)
opt_centroids, opt_labels = k_clustering(gene_expression, opt_k, max_iter=70, message=True)
Iteration: 0
71.500000% labels changed
Iteration: 1
29.375000% labels changed
Iteration: 2
11.000000% labels changed
Iteration: 3
9.625000% labels changed
Iteration: 4
7.375000% labels changed
Iteration: 5
4.250000% labels changed
Iteration: 6
2.875000% labels changed
Iteration: 7
1.500000% labels changed
Iteration: 8
1.250000% labels changed
Iteration: 9
0.750000% labels changed
Iteration: 10
0.500000% labels changed
Iteration: 11
0.250000% labels changed
Iteration: 12
Labels unchanged! Terminating k-means.
In [89]:
# construct a dictionary to sum of the size for each cluster
cluster_sizes = defaultdict(int)

for i in range(opt_k):
    cluster_sizes[i] = np.sum(opt_labels==i)

print("The optimal value of k is: ", opt_k)
print("The cluster sizes found by the optimal k are: ", cluster_sizes)
The optimal value of k is:  4
The cluster sizes found by the optimal k are:  defaultdict(<class 'int'>, {0: 303, 1: 135, 2: 221, 3: 141})

2.1.2 Consistency of k-means clustering¶

In [90]:
# optimal k from 2.1.1
opt_k = 2 + np.argmax(k_range_ch_index)

def h_score(opt_k, gene_type):
    """
    Return the average H(C) score based on optimal k
    Arg: 
        opt_k: optimal k from 2.1.1
        gene_type: true label of the gene expressions.
    """
    if type(gene_type) != np.ndarray:
        gene_type = gene_type.to_numpy()

    # initialization
    a = defaultdict(int)
    types = np.unique(gene_type)
    n = len(gene_type)
    h_c, h_ck = 0, 0
    h_c_lis, h_ck_lis = [], []

    for _ in range(5):
        up_centroids, up_labels = k_clustering(gene_expression, opt_k, max_iter=20, message=False)

        for c in types:
            for k in range(opt_k):
                a[(c, k)] = sum((gene_type==c) & (up_labels==k))

        # to compute H(c)
        for c in types:
            sig_ack = sum(a[(c, k)] for k in range(opt_k))
            h_c -= sig_ack/n * np.log(sig_ack/n)
        h_c_lis.append(h_c)
        
        # to compute H(c|k)
        for k in range(opt_k):
            sig_cck = sum(a[(j, k)] for j in types)

            # skip if 0 is encounted in the log
            for c in types:
                if a[(c, k)] != 0:
                    h_ck -= a[(c, k)]/n * np.log(a[(c, k)]/sig_cck)
        h_ck_lis.append(h_ck)

    return 1 - np.mean(h_ck_lis)/np.mean(h_c_lis)
In [91]:
# print the homogeneity score
hc = h_score(opt_k, gene_type)
print("The homogeneity score is: ", hc)
The homogeneity score is:  0.5824509972004817

Comment:¶

First to note: In defining h_score (homogeneity score), skipped some of the nan in a[(c, k)], meaning that none of the points in class c has beenclustered in cluster k. This will lead to a slight underestimation of the true homogeneity score.

By the set-up of the h_score, the higher the score, the better alignment with the original labels. From this perspective, the clustered labels are not highly consistent with the original labels.

Although CH_k gives the optimal clustering case, there could be some reasons for low h_score:

  • k-measn is dependent on initialization. In this very first stage, if several points from different true classes are assigned with the same cluster label, they are likely to form an individual cluster and will affect the outcome of clustering. If time permitts, several representative samples points should be chosen as initial centroids.

  • "CH_index" gives optimal clustering but necessarily the 'best' clustering. It has been shown to be sensitive to the shape of the clusters and the presence of outliers. If the initial assignments of labels give a considerable number of outliers, and they will be counted towards a certain cluster, despite the largest distance. This will consequently effect the results of centroids and thus the result of clutering. If time permits, checks for outliers and other structures of the label initialization should be implemented, and other measures of quality should be used in combination with this.

Note: the homogeneity score increases with k (shown below). This means that the greater the k, the better alignment of the clustering. But too large a k is likely to cause overfitting. Thus, the choice of k should be carefully chosen.

showing that homogeneity score increases with k¶
In [92]:
h_l = []

for k in range(2, 14):
    h_l.append(h_score(k, gene_type))
In [93]:
plt.plot(range(2, 14), h_l, label="h_socre")
plt.scatter(range(2, 14), h_l)
plt.xlabel("cluster number k")
plt.ylabel("homogeneity score")
plt.title("h score against number of clusters")
plt.legend()
plt.show()
results with sk.learn¶
In [94]:
from sklearn.cluster import KMeans
from sklearn.metrics import homogeneity_score


types = np.unique(gene_type)
c_dict = {types[i]: i for i in range(len(types))}
print(c_dict)

# Building the clustering model
kmeans = KMeans(n_clusters=opt_k)
# Training the clustering mode
kmeans.fit(gene_expression)
# Storing the predicted Clustering labels
labels = kmeans.predict(gene_expression)  
# Evaluating the performance
print(homogeneity_score(gene_type, labels))
{'BRCA': 0, 'COAD': 1, 'KIRC': 2, 'LUAD': 3, 'PRAD': 4}
0.586797136972071

2.2 Graph-based analysis (20 marks)¶

For a network, centrality (representative nodes), communities and modularity(degree of connections within and between communities) are topics of concern. In this section, these topics are applied and discussed.

2.2.1 Imshow the adjacency matrix¶

In [95]:
# recall the unnormalized data: gene_data
gene_expr_22 = gene_data[gene_data.columns[:-1]].astype(float).to_numpy()
#connected correlations
cor_mat = np.corrcoef(gene_expr_22, rowvar=False)
# adjacency matrix
np.fill_diagonal(cor_mat, 0)
cor_mat[np.abs(cor_mat) < 0.75] = 0
A = cor_mat
In [96]:
print(gene_expr_22.shape)
print(cor_mat.shape)
(800, 95)
(95, 95)
In [97]:
plt.imshow(A)
plt.colorbar();

2.2.2 centrality: top 5 representative nodes¶

In [98]:
# degree centrality
degree = A.sum(axis=1)
sorted_index = np.argsort(degree)[::-1]
sorted_gene_expr_22 = gene_expr_22[sorted_index]
In [99]:
print("The index of the five topping ranking genes are: ", sorted_index[0:5])
print("The top 5 centralties are: ", degree[sorted_index[0:5]])
print("The top 5 ranking gene expressions are:")
for i in range(5):
    print(gene_expr_22[i])
The index of the five topping ranking genes are:  [17 41 16 81 90]
The top 5 centralties are:  [19.73839866 19.04945567 18.64028395 18.53179044 18.2670517 ]
The top 5 ranking gene expressions are:
[ 9.79608829  0.59187087  0.59187087  0.         11.42057082 13.45375934
  4.41184652  5.41233442 10.77161327 10.22566536 10.03868584  5.51190126
  5.77501102 10.92286682  5.6050409   6.05361315  8.40630281  7.72084618
  5.74803716  7.47570912  7.15991169  5.39504909  2.47622613  3.92603738
  1.01027857 13.83498451 13.87784967 13.7711593  10.67108988  0.
  0.          9.52826851  8.82942148  7.82453875 12.21663674  9.84065835
  5.01986645  5.90279957 10.1099351   5.93902906  5.99745726  5.63479653
  7.3927639   4.26735601  2.47622613  7.67861422  4.97334034  5.04224159
  3.26629182  0.59187087 12.2261382  10.91227561 11.36228906 11.84883025
 12.08169979 10.39063115 14.19514917  8.83903683  3.01795753  6.29636379
  5.16967652 17.17356979 18.52516138 14.12612366  6.7207436   7.27775234
 10.6862729   0.          0.          0.          0.          0.
  4.69693415  8.83951875 11.44045394  0.          0.          5.58992824
  9.99201778  0.          4.80112243  6.89684097  6.89684097  2.01539052
  8.57886703  9.1671095   5.9743687   8.08651311 12.72775032 15.2057169
  6.4381165   6.41257662  0.          6.81472985 13.61814457]
[10.07046983  0.          0.          0.         13.08567162 14.53186268
 10.46229777  9.8329264  13.52031174 13.96804574 13.79901848  8.260228
  9.65216737  9.09621968  8.18865313  9.70963528 11.96387506 11.28256722
  9.51009952 10.51952814 11.39463418  9.56670291  8.87603187  1.32716997
  0.58784501 14.76843764 14.23302043 14.57201329 11.66463839  0.
  0.          7.29642987  0.          0.         11.92025608  0.
  9.82974017 10.07282933  0.          0.58784501  0.          9.46716738
  8.88319924  8.42131588  8.22548703 11.04734207  8.51354601  7.9453736
  6.35739365  4.41775127  8.92761798  5.92151726 14.18870474 13.40419691
 15.60740834 14.78221504 14.77833898  0.          8.78683156  9.78190226
 10.4486016   0.          0.          0.81114217  0.         10.29676711
 13.20786846  0.          0.          0.          0.          0.
  6.47383236 12.64973894 12.02801309  0.32365829  0.32365829  8.29183375
  7.79663669  0.32365829  9.40429432 10.54647008  0.          7.42867841
  8.12118122  9.1224346   0.          0.         11.1972044  12.99393259
 10.80074624 10.74981078  0.         11.44560981  0.        ]
[ 8.97091978  0.          0.45259543  0.          8.26311894  9.75490754
  8.96454881  9.94811313  8.69377268  8.77611057  8.76759852  4.00972338
  6.78232993  8.10219094  5.33731148  6.87130155  7.54758093  7.39801682
  5.87905631  9.87243326 10.32239013  8.854005    8.46767483  3.75560333
  2.53366301 11.63681491 10.90115351 11.42809318  8.61324801  0.
  0.          6.29248177  8.80492477  8.40945034  6.16492301 10.48552746
  9.22034447  5.94094193  9.80390508  0.          0.          6.08102852
  6.13073529  7.75868294  6.58603511  6.61344178  4.23873337  4.70363231
  2.6226727   2.53366301  7.05841391  4.50723013  8.70017612  7.95997876
 10.36049597 11.0161397  13.34464267  9.93078759  8.68604874  8.82616393
  8.85973666 14.8184224  16.05359715 10.80903655  6.8758292   6.37165643
  9.71371643  0.          0.          0.          0.          0.
  1.96484219 10.2450535   6.22301316  4.04235542  3.45276673  4.72389162
  7.1435975   0.79659775  9.45885361  6.40967266  8.4974008   2.22801825
  6.89829302  8.93189576  3.90715987  5.32410132 11.48706624 13.38059635
  6.65623607 10.20973359  0.          7.74883018 12.75997554]
[ 8.52461615  1.03941918  0.43488172  0.         10.7985204  12.26301973
  7.44069479  8.06234301  8.80208333  9.23748724  9.35917193  5.80423946
  5.10517519  8.00027048  4.07300613  5.73921895  7.84116804  7.28668712
  5.74868932  7.61131995  8.13093602  7.41416091  6.17115878  4.13209148
  2.80330971 13.46229777 13.20099973 13.24938608 10.39924588  0.
  0.          6.56876026 10.07297667  8.76872689  9.41674054 10.07108689
  7.47532812  5.9672424   9.28544846  0.76858664  0.43488172  5.72008963
  6.86068972  5.5571803   5.32368351  6.63153214  3.89136086  4.98211286
  2.80330971  2.47853181 10.60907714  7.78244133 10.21082961  9.91634527
 12.92707792  9.6801151  12.92714461 10.28753925  6.56876026  7.13967413
  8.12367654 17.37107895 18.37179366 13.77467383  5.52447433  6.60570915
  9.71789915  0.          0.          0.          0.          0.
  4.1430505  11.03974158  9.90527267  0.          0.76858664  3.18447074
  8.04898812  1.26735601  6.88442698  6.07656551  7.99232092  4.13209148
  7.55305325  8.96062754  4.29608292  6.95974698 12.97463865 14.89181218
  6.03072451  7.31564774  0.43488172  7.11792356 12.35327642]
[ 8.04723845  0.          0.          0.36098224 12.2830102  14.03375851
  8.71918002  8.83147193  8.46207277  8.21120206  8.23725777  8.67169674
  4.71700735  9.85755963  2.85877675  4.34557379  8.02536661  7.11465853
  5.5386499   9.12604104  8.10498132  8.41117623  6.82021726  4.99935515
  5.27829333 17.30907662 16.93580962 16.64340528 13.74022347  0.
  0.          8.21258879  0.          0.         13.60890163  0.
  8.14464805  5.40951992  0.          5.16854176  4.26253839  5.64082045
  8.24790378  6.0589009   5.04975273  7.23387848  5.54288019  5.5386499
  3.95400084  1.09565442  8.99392339  6.10729451 10.37130837  9.45944391
 12.18371911 10.86885341 13.04320675  0.          6.79828386  8.69167374
  8.62882826  1.58009723  0.          5.93701005  2.96762975  6.57128545
 10.0700941   0.          0.          0.          0.          0.
  5.48910616  9.27257613 14.03528024  0.64938553  0.          4.1046302
  8.88363303  1.94212035  7.28999119  6.21611142  1.71114245  7.1839926
  6.0215086   7.3386116   0.          0.         11.33723721 13.39006145
  5.98959318  8.35967051  0.          6.32754546  0.        ]

2.2.3 subgraphs of the network: number of 0 eigenvalues¶

In [100]:
# symmetric normalized Laplacian function
def compute_l_norm(A):
    """
    Return the normalized Laplacian.
    Arg:
        A: Adjacency matrix"""

    weighted_degree = A.sum(axis=1)
    D = np.diag(weighted_degree)  # degree matrix D
    # L_norm
    weighted_degree_sqrt = 1.0 / np.sqrt(weighted_degree)
    D_inv_sqrt = np.diag(weighted_degree_sqrt)
    L_norm = np.eye(A.shape[0]) - D_inv_sqrt.dot(A.dot(D_inv_sqrt))

    return L_norm
In [101]:
# eigen decomposition
L_norm = compute_l_norm(A)
eigenvals, eigenvecs = np.linalg.eigh(L_norm)
In [102]:
eigenvals
Out[102]:
array([-1.25666763e-15, -1.22009119e-15, -5.86259751e-16, -2.82855538e-16,
       -2.76761826e-16, -2.26146964e-16, -1.95100938e-16, -1.49770206e-16,
       -9.98842225e-17,  4.49123519e-18,  6.31547765e-17,  2.77555756e-16,
        3.24301287e-16,  3.33281305e-16,  3.96520475e-16,  6.14434034e-16,
        7.29563077e-16,  1.13662133e-15,  4.40674341e-01,  7.09824816e-01,
        8.09205368e-01,  8.74415544e-01,  9.49448789e-01,  9.93084679e-01,
        1.00281813e+00,  1.01076222e+00,  1.02782108e+00,  1.04465433e+00,
        1.04783531e+00,  1.04956517e+00,  1.05073999e+00,  1.05146914e+00,
        1.05162853e+00,  1.05387651e+00,  1.05532375e+00,  1.05570299e+00,
        1.05633592e+00,  1.05665956e+00,  1.05968737e+00,  1.07419326e+00,
        1.07444423e+00,  1.07577484e+00,  1.07772569e+00,  1.07840825e+00,
        1.07930189e+00,  1.07966193e+00,  1.08036281e+00,  1.08130455e+00,
        1.08252594e+00,  1.08512021e+00,  1.08600827e+00,  1.09526489e+00,
        1.11444086e+00,  1.12087585e+00,  1.13008862e+00,  1.13374849e+00,
        1.13484766e+00,  1.13769825e+00,  1.14151998e+00,  1.15353425e+00,
        1.16016346e+00,  1.16477260e+00,  1.16726495e+00,  1.16883048e+00,
        1.16939703e+00,  1.16957148e+00,  1.18065418e+00,  1.18908346e+00,
        1.19607574e+00,  1.19941793e+00,  1.21379412e+00,  1.21536563e+00,
        1.21748346e+00,  1.21859141e+00,  1.22229890e+00,  1.22473975e+00,
        1.25620720e+00,  1.26225785e+00,  1.26294794e+00,  1.26405150e+00,
        1.27433950e+00,  1.45496166e+00,  1.49789651e+00,  1.50210349e+00,
        1.57934555e+00,  2.00000000e+00,  2.00000000e+00,  2.00000000e+00,
        2.00000000e+00,  2.00000000e+00,  2.00000000e+00,  2.00000000e+00,
        2.00000000e+00,  2.00000000e+00,  2.00000000e+00])
In [103]:
# Plotting the spectrum
r = 0
fig, ax = plt.subplots(1)
plt.plot(eigenvals)
for i in range(len(eigenvals)):
    if abs(eigenvals[i]) < 1e-14:
        plt.scatter(i, eigenvals[i], color="red")
        r += 1
    else:
        plt.scatter(i, eigenvals[i], color="blue")

plt.scatter(0, eigenvals[0], color="red", label="rounded zero eigenvalues")
plt.scatter(len(eigenvals), eigenvals[-1], color="blue", label="non-zero eigenvalues")

# gap
plt.axvline(r-0.5, color="red", linestyle="--", label="gap=1e-14")
plt.xlabel("index")
plt.ylabel("Eigenvalues")
plt.title("Spectrum of eigenvalues.")
plt.legend()
plt.grid()
plt.show()

print("The number of zero eigenvalues is: r=", r)
The number of zero eigenvalues is: r= 18

Analysis:¶

Set the threshold as 1e-14, and round down eigenvalues smaller than it to 0. By doing so, it's obtained 18 eigenvalues are rounded down to 0 and marked red in the plot.

By lecture notes, the number of zero eigenvalues correspond to the components of the graph. Therefore, $r=18$ tells that there are 18 components consisting of the graph. In reality, it's possible that there is/are edges connecting components, but quite few.

2.2.4 Elbow method: the optimal k for clustering¶

In this section, when finding the optimal k using elbow method, a polynomial of order 8 is fitted to the within-cluster distance:

  • The loss curve is not smooth, thus determining k directly from the graph will lead to an inaccurate result.

  • The order of the polynomial is adjusted through trials to give an appropriate fit to the loss but without overfitting the data.

The optimal value of k is determined at the largest k where the gradient of the loss is less than -1:

  • We aim to find a point where the absolute gradient to the loss curve is decreasing to zero and, at the same time, the value of k is small. So value of -1 is chosen so that gradients after this point give absolute loss change less than 1.
In [104]:
# U
U = eigenvecs[:, :r]
# constructing T
row_norms = np.linalg.norm(U, axis=1)
D = np.diag(1/row_norms)
T = D.dot(U)
In [105]:
# define elbow function
def norm_within_cluster_dis(X, cluster_labels):
    """
    Return the w_c cost.
    Args:
        X: data set
        cluster_labels: updated cluster labels
        k: number of clusters.
    """
    w_c = 0
    for i in np.unique(cluster_labels):
        # extract the corresponding elmts
        cluster_elemts = X[cluster_labels==i, :]
        for j in cluster_elemts:
            for k in cluster_elemts:
                w_c += 0.5* np.linalg.norm(j-k)**2 / len(cluster_elemts)
    return w_c
In [106]:
# recall the k-means clustering algorithm
np.random.seed(42)
k_range = range(2, 40)
w_c = []
times = 100

# implement clustering with each k for 100 times and choose the one with minimal cost
for k in k_range:
    holding_l = []
    for i in range(times):
        up_centroids, up_labels = k_clustering(T, k, max_iter=100, message=False)
        holding_l.append(norm_within_cluster_dis(T, up_labels))
    wc_cost = np.min(holding_l)
    w_c.append(wc_cost)
In [107]:
# the within-cluster distance cost
print(w_c)
[64.61111111110448, 51.07142857143121, 42.851063829786696, 36.34999999999901, 33.05128205128229, 26.324110671936715, 23.780219780219834, 21.303030303030248, 17.76923076923074, 15.000000000000021, 13.111111111111118, 11.399999999999991, 8.857142857142863, 6.400000000000001, 4.4, 4.0, 3.9999999999999982, 3.1111111111111125, 1.9999999999999991, 2.0, 9.4374449716689e-29, 1.9999999999999996, 9.335365708757868e-29, 8.914833625882453e-29, 8.627403593411206e-29, 8.732139564492352e-29, 7.975464815141715e-29, 8.544480761022826e-29, 7.987259597634385e-29, 7.971117744695417e-29, 8.244491999274519e-29, 7.560213599202927e-29, 7.12001438346075e-29, 7.246370752697711e-29, 6.970868871646457e-29, 6.955285149371518e-29, 6.244727734560932e-29, 6.136287568266725e-29]
elbow method: seeting an epsilon¶

For elbow method, we want to tradeoff between k and the cost-- a k such that the cost is low while the k is small as well.

From lecture notes, the elbow curve is deceasing fast and then the decreasing rate slows down and finally remains almost stable. The elbow point is set at the 'middle' change-phase ie. the region where the curve's decreasing rate is slowing down.

Therefore, a tolerance eps should be set and adjusted through trials. (Intuitively from the plot, we should expect that it falls between 12-20.)

In [108]:
eps = 2  # tolerance
for i in range(len(w_c)):
    if w_c[i] <= eps:
        elbow_k = i + 2
        print("The elbow k is: ", elbow_k)
        print("The cost is: ", w_c[i])
        break
The elbow k is:  20
The cost is:  1.9999999999999991
In [109]:
plt.plot(k_range, w_c)
plt.scatter(k_range, w_c, label='Distance Values')
plt.scatter(elbow_k, w_c[elbow_k-2], color="red", label="the elbow k")
plt.legend()
plt.xlabel("number of clusters k")
plt.ylabel("within cluster distance")
plt.title("The within cluster distance against cluster number k")
plt.show()
In [110]:
# obtain clustering for the optimal k
np.random.seed(4)
centroids, labels = k_clustering(T, elbow_k, max_iter=100, message=True)

# construct a dictionary to hold the indexes of points for each label
cluster_dict = defaultdict(list)
for i in range(elbow_k):
    cluster_dict[i].extend(list(np.where(labels==i)[0]))

print("The value of k at the elbow point is: ", elbow_k)
print("The clustering by elbow k: ", cluster_dict)
Iteration: 0
72.631579% labels changed
Iteration: 1
Labels unchanged! Terminating k-means.
The value of k at the elbow point is:  20
The clustering by elbow k:  defaultdict(<class 'list'>, {0: [], 1: [84, 85], 2: [49, 73], 3: [29, 30, 50, 51], 4: [6, 7, 19, 20, 21, 22, 36, 43, 44, 55, 58, 59, 60, 77, 80, 91], 5: [8, 9, 10, 12, 14, 15, 16, 17, 18, 37, 41, 45, 46, 47, 48, 52, 53, 54, 65, 66, 81, 90, 93], 6: [1, 2, 3, 23, 24, 39, 40, 79, 92], 7: [0, 56, 75, 76], 8: [], 9: [25, 26, 27, 28, 34, 74], 10: [], 11: [42, 61, 62, 63, 64, 72, 86, 87], 12: [4, 5], 13: [], 14: [11, 13, 31, 78, 83], 15: [67, 68, 69, 70, 71], 16: [], 17: [32, 33, 35, 38, 57, 82, 94], 18: [88, 89], 19: []})
In [111]:
# define a function to catch the size of the clusters of T
def cluster_size(cluster_dict):
    """
    Return the sizes of each cluster of T.
    """
    size_dict = defaultdict(int)
    for i in cluster_dict:
        size_dict[i] = len(cluster_dict[i])

    return size_dict

size_dict = cluster_size(cluster_dict)
print("The size of each cluster by elbow k is: ", size_dict)
The size of each cluster by elbow k is:  defaultdict(<class 'int'>, {0: 0, 1: 2, 2: 2, 3: 4, 4: 16, 5: 23, 6: 9, 7: 4, 8: 0, 9: 6, 10: 0, 11: 8, 12: 2, 13: 0, 14: 5, 15: 5, 16: 0, 17: 7, 18: 2, 19: 0})

Interpretation to the graph structure:¶

  • The appropriateness of elbow k:

By seeting a tolerance eps, elbow k is found at k==20. Each column of T is an eigenvector corresponds to a unique subgraph of the original network. Therefore, clustering T should give an optimal k close to r=18. The k found by elbow problem makes sense from this perspective.

  • construction & clustering of T:

The construction of T is equivalent to projecting the nodes into a lower dimensional space (dim=18 in this case)

Similar to clustering implemented in coursework 1, clusters found have smallest total within-cluster distance and greatest total inter-cluster diatances. And these two criteria are substituted by connectivity: as many as within-cluster edges and as less as inter-cluster edges. Implementing k-means clustering, we found that some of the labels have zero samples clustered in that cluster, and this is because the update of labels with the change of centroids in each iteration of k-means algorithm.

The clustering of T provides an insight into the modularity(how clusters connect) of the network. For T:

$\cdot$ each row represents the behavior of a node in these zero eigen-spaces. By clustering results, there are 14 clusters of similar behavior of the nodes, where each cluster is densely connedted within itself and sparsely connected with other clusters. This gives an insight of which of the features of the genes are more connected and which are not.

  • clustering by T & by 2.1.3:

By 2.1.3, there are 18 clusters. But for the result of clustering of T, there are 14 non-empty clusters and 6 empty clusters. This means that some of clusters must have been merged to form a bigger one in kmeans clustering. Recall the process of updating centroids of each cluster and reassigning labels--bigger clusters tend to drag small clusters and merge them as they are a main force of dragging the centoids. In terms of the graph structure, some small clusters are more connected to bigger ones, eg. the cluster labelled 1 can be more likely to be more attached to the cluster labelled 4.

2.2.5: Spectral partitioning and binary partition of the biggest subgraph¶

In [112]:
# get the largest cluster and corresponding node indices
largest_cluster = list(size_dict.keys())[np.argmax(list(size_dict.values()))]
largest_cluster_indices = cluster_dict[largest_cluster]
# get subgraph and corresponding adjcency matrix
mask = (labels == largest_cluster)
sub_A = A[mask,:][:,mask]
# corresponding indexes of the original network
mask_indx = np.where(mask==True)

# get sub_Lagrangian
sub_L_norm = compute_l_norm(sub_A)
sub_eigenvals, sub_eigenvecs = np.linalg.eigh(sub_L_norm)
# perform binary spectral partition
second_index = np.where(sub_eigenvals>1e-7)[0][0] # first nontrivial index
spectral_partition = sub_eigenvecs[:,second_index]
spectral_partition[spectral_partition<0] = 0 
spectral_partition[spectral_partition>0] = 1 
In [113]:
p1 = np.where(spectral_partition == 0)[0]
p2 = np.where(spectral_partition == 1)[0]
n1 = len(p1)
n2 = len(p2)

# Define the ratios for the subplots
width_ratios = [1, 1, 1.8]
# Create the figure and the grid
fig = plt.figure(figsize=(25, 15))
gs = gridspec.GridSpec(1, 3, width_ratios=width_ratios)

# first plot
ax0 = plt.subplot(gs[0])
im0 = ax0.imshow(sub_A[p1,:][:,p1], cmap='viridis', vmin=-1, vmax=1)
ax0.set_xticks(np.arange(n1))
ax0.set_yticks(np.arange(n1))
ax0.set_xticklabels(np.array(largest_cluster_indices)[p1])
ax0.set_yticklabels(np.array(largest_cluster_indices)[p1])
ax0.set_title('Magnitude of Network Links within Partition 1')

# second plot
ax1 = plt.subplot(gs[1])
im1 = ax1.imshow(sub_A[p2,:][:,p2], cmap='viridis', vmin=-1, vmax=1)
ax1.set_xticks(np.arange(n2))
ax1.set_yticks(np.arange(n2))
ax1.set_xticklabels(np.array(largest_cluster_indices)[p2])
ax1.set_yticklabels(np.array(largest_cluster_indices)[p2])
ax1.set_title('Magnitude of Network Links within Partition 2')

# third plot
ax2 = plt.subplot(gs[2])
im2 = ax2.imshow(sub_A[p1,:][:,p2], cmap='viridis', vmin=-1, vmax=1)
ax2.set_xticks(np.arange(n2))
ax2.set_yticks(np.arange(n1))
ax2.set_xticklabels(np.array(largest_cluster_indices)[p2])
ax2.set_yticklabels(np.array(largest_cluster_indices)[p1])
ax2.set_title('Magnitude of the Network Links across Partitions')

# Add a color bar
cbar2 = fig.colorbar(im2, ax=ax2, shrink=0.4)
cbar2.set_label('Magnitude of Lines')

# Show the plot
plt.show()

Comment on the pattern:¶

A corresponding network graph is plotted below.

within partitions:

  • partition 1: any node is connected with other nodes with a single edge except the pair (37, 14) and (93, 8)
  • partition 2: all nodes are connected with a single edge
  • Comparing two individual partitions, partition 2 is more closely connected, which can be confirmed with the partial graph below connected in blue.

across partitions:

  • not all the nodes are directly connected with a single edge. Particularly, node 14 in partition 1 is sparsely connected with all nodes in partition 2 with a single edge, which corresponds to the farthest red node in the network graph below.
  • node 12 and 14 in partition 1 are also sparsely connected to partition 2, with a few connected edges. They correspond to the red nodes at the 'boundaries' of 2 partitions. Similarly, node 46, 47 and 48 correspond to the blue nodes at the partition boundary, as they are only sinply connected to a few nodes in partition 1.
  • node 37 in partition 1 is sparsely connected with node 48 in partition 2
compliment to the analysis¶
In [114]:
# drawing networkx graph of sub_A, and coloring partitions with 2 colors
g = nx.Graph(sub_A)
colored_nodes = p1
node_colors = ['red' if node in colored_nodes else 'blue' for node in g.nodes()]
nx.draw(g, node_size=30, node_color=node_colors)

2.2.6 Centralities of the subgraph¶

In [115]:
# degree centrality
sub_degree = sub_A.sum(axis=1)
sub_sorted_index = np.argsort(sub_degree)[::-1]
sub_sorted_gene_expr_22 = gene_expr_22[sub_sorted_index]
print("The index of the five topping ranking genes are: ", mask_indx[0][sub_sorted_index[0:5]])
print("The top 5 centralities are: ", sub_degree[sub_sorted_index[0:5]])
print("The top 5 ranking gene expressions are:",)
for i in range(5):
    print(sub_sorted_gene_expr_22[i])
The index of the five topping ranking genes are:  [17 41 16 81 90]
The top 5 centralities are:  [19.73839866 19.04945567 18.64028395 18.53179044 18.2670517 ]
The top 5 ranking gene expressions are:
[ 8.92200751  1.65526023  0.44180215  0.         12.06016901 13.67437988
  4.39674177  4.42107515 11.38857703 10.95483607 10.92828889  8.57963273
  7.98435579 11.33975538  7.18736197  8.85325653 10.24577848  9.43419452
  7.63761749  5.20278991  5.1010365   5.31091447  2.96487914  2.89710486
  1.48104056 15.29502989 16.03244183 14.99020592 12.31127434  0.
  0.          9.66501785  7.91964995  6.48955899 15.18833385 10.71144328
  3.41602901  8.27041214  9.32156512  6.09506783  4.76317252  8.31978049
  8.32314035  4.62336356  3.63985727  8.95049758  4.28954678  4.89062179
  3.0296119   0.7795539   9.9098121   7.53265203 12.89064048 12.39811037
 13.96305862  9.58934029 13.53431504  9.87500439  3.0296119   4.7820835
  4.53694187 15.74874843 17.60456047 14.13791172  7.20498189  8.82161775
 11.65521317  0.          0.          0.          0.          0.
  5.19462292 10.23099312 15.05184618  0.          0.          3.55444151
 10.49176293  1.9509909   4.51450104  8.16629868  8.70011022  6.15432599
  5.75562469  7.39119203  3.31576909  5.94372914 11.72023571 13.80957723
  7.48439596  4.26850925  0.          9.43867128 12.81238792]
[ 8.11540786  0.          0.          0.         11.44178743 13.13296716
  8.05858708  8.29733359 10.5051069  10.32072159 10.13971761  6.79013342
  9.19840821 10.11150086  7.84939885  9.59669107 10.39006145  9.82915823
  7.10945459  9.20889031  8.13752402  8.50382971  5.93018144  1.83608571
  1.65168318 16.53078428 16.06013438 16.53424512 12.50169788  0.
  0.          8.97098016  0.          0.         14.05376471  0.
  7.47864021  8.73172632  0.          5.93018144  4.87157755  9.33482523
  8.79097073  4.9335632   4.51403386  8.31102168  5.84672648  6.10198197
  4.40156221  5.76951787  8.75050267  5.9994431  12.83481865 12.30149048
 13.76738789 10.97626338 12.64077816  0.          6.61835026  8.49529143
  8.08841725  2.61842662  3.12288803  4.61836788  0.51440004  8.40153661
 11.55723871  0.          0.          0.          0.          0.
  5.89009887 14.34901163 13.52010299  0.          0.          4.71566278
 10.18607709  3.19214681  7.84671895  8.71711431  0.          4.31105257
  6.12872644  7.40423045  0.          0.         11.37563681 13.30660288
  8.35132355  8.02101326  0.         10.14791719  0.        ]
[ 9.97363997  6.8802937   5.68310691  0.         11.61823407 13.9518584
  5.29798006  4.20064559 11.16863451 10.96875318 10.40217003  8.33689033
  6.85273548 11.94623374  5.42318148  7.08838545  9.4623264   8.47163457
  6.70643385  6.81895279  5.91398082  5.85963169  3.70458428  7.30970371
  7.01568266 15.17014023 15.5444936  15.36038417 12.44458529  0.
  0.         10.70656509  6.76674091  7.51841711 12.76486745 10.61807344
  5.29798006  5.85963169  8.34877683  9.32666767 10.64202461  7.65552357
  8.95335837  4.79441067  3.17392693  8.57307151  6.70025937  7.15308584
  5.75632681  2.76464339  9.99585193  8.44604527 12.11841878 11.52265509
 12.32970666  7.80689757 14.71886461  9.48404546  4.56909002  6.63056419
  6.16766301  0.          0.          0.          0.          8.49856203
 11.78353032  0.          0.          0.          0.          0.
  6.8886215  10.39599461 11.78316605  0.          0.          2.32553033
 10.94708897  4.7710457   3.75310173  6.96510317  7.50785805  5.56579517
  6.59107947  8.02012434  0.          0.         11.16275007 13.48954695
  7.12571309  5.84851278  0.          7.16657486 12.34301356]
[10.10335338  0.          0.          0.         12.02755703 13.95778223
 10.62569058  9.96159873  8.42283848  9.34314573  9.2134931   6.2207962
  5.24691347 10.15564046  2.7654711   5.1618232   8.9703619   9.0524322
  7.54093394  8.64549993  8.77051524  8.8129923   8.77171064  0.78659636
  0.         16.93219168 17.2118978  17.00718061 13.89325412  0.
  0.          8.2290495   0.          0.         14.07698257  0.
 10.85953479  6.89686518  0.          5.0399499   3.98748459  8.63895715
  7.07066839  7.97843019  6.49458842  9.91143693  6.89168651  6.37366909
  4.95819773  1.82268936  4.97492389  1.96343713 11.14972172  9.2405457
 12.5257141   9.55806861 14.49403678  0.          9.02458841  8.80949805
  9.46802118  0.44625623  0.          6.05765388  2.31825899  8.31202071
 12.08188616  0.          0.          0.          0.          0.
  3.53908435 10.11297374 15.3735483   0.          0.          0.78659636
  8.43648669  1.06170708  8.72303971  8.4678419   0.78659636  3.77472382
  4.20157113  5.43115157  0.          0.44625623 11.30287878 13.46658634
  7.78807881  9.16694386  0.          8.52432623  0.        ]
[ 8.69782594  3.51765364  3.86295716  8.78777199 12.33524239 13.99592611
  8.04792549  9.37627519 12.82544579 13.02612398 12.83938293  8.57285557
  7.80233541  9.14525457  6.80379532  8.36665006 11.64664421 10.41460068
  8.88902628  8.77635407  8.38179816  7.94515702  7.02577078  5.50184674
  4.4013301  14.44136209 14.23803894 14.70011997 12.05164739  0.
  0.          7.64660191  8.79533513  6.26455144 12.59471415 10.78215462
  7.58974851  9.58193304 10.32730562  2.1648828   2.1648828   8.85565651
 10.49858996  6.1033025   5.4585559  10.04753294  8.47262992  8.23195075
  6.58853625  3.41666412  5.75593174  3.06277796 12.74262944 12.09926232
 14.42464527 11.2576648  13.47460589  9.04519721  6.51715441  7.88263082
  8.74545248  1.11196609  0.          2.39467876  0.4720718  10.01820018
 13.22198278  0.          0.          0.          0.          0.
  7.88028758 11.61066961 12.96694158  1.11196609  0.          4.75161368
  8.47107363  4.2627263   6.62294499  9.8383231   8.14385766  7.10093561
  6.51104418  7.85819241  0.          0.         11.85066349 13.89602014
  9.93506631  8.88075592  4.34747432 10.40391742 14.74470772]
In [116]:
print("The index of the top 5 ranking genes are: ", sorted_index[0:5])
print("The centralities of top 5 gene expressions of the A are:", degree[sorted_index[0:5]])
print("The index of the top 5 ranking genes are: ", mask_indx[0][sub_sorted_index[0:5]])
print("The centralities of top 5 gene expressions of the sub_A are:", sub_degree[sub_sorted_index[0:5]])
The index of the top 5 ranking genes are:  [17 41 16 81 90]
The centralities of top 5 gene expressions of the A are: [19.73839866 19.04945567 18.64028395 18.53179044 18.2670517 ]
The index of the top 5 ranking genes are:  [17 41 16 81 90]
The centralities of top 5 gene expressions of the sub_A are: [19.73839866 19.04945567 18.64028395 18.53179044 18.2670517 ]

Analysis:¶

17 41 16 81 90
centrality of original graph 19.73839866 19.04945567 18.64028395 18.53179044 18.2670517
centrality of the sub graph 19.73839866 19.04945567 18.64028395 18.53179044 18.2670517

By the printed result above, the top 5 degree centrality of the sub graph remain identical to the orignal network.

This makes sense by the way we found the elbow k, and the biggest cluster:

  • The elbow k is found by the minimal cost of 100 iterations of each k and among a range of k, guaranteed the 'best' clustering of the graph as possibly. After this step, nodes with more edges connected to each other are more likely to be clustered as a subgraph.
  • The subgraph is constructed from the biggest cluster from the clusters found by the optimal k, which is, at the highest possibility, the most densely knitted subgraph of the network.

The centrality of a network indicates the importance of each node: the higher the centrality, the more representative of a node. Therefore, the top 5 centralities of the subgraph is expected to be identical to be the global top 5 centralities of the original network.

That's all. Thanks for your reading!¶

Task 3: Mastery component (25 marks)¶

Please delete this section if you are not a master student